K-D 树 笔记

用于解决 \(k\) 维区间问题的数据结构。

P4148 简单题

这是一种 2-D 树,因为只有两维。

K-D 树用的是类似平衡树的思路,也就是根节点本身也代表一个节点信息,故选择一个维度,用 nth_element 函数找到这个维度居中的点,把它作为根节点即可。

对于一个正方形,横着切一刀竖着切一刀时间复杂度显然优,但是要是对于一个平放的木棍,横着切完全没必要,所以每次按照方差较大的维度进行分治。

插入点之后,不平衡了怎么办?直接用替罪羊树思路暴力重建。具体地,设定一个 \(\alpha=0.75\),若某个点的儿子的子树大小,已经超过了这个点的子树大小的 \(\alpha\) 倍,就把这整个子树的点记录下来,直接重建。

查询和线段树是一样的,如果几个端点都满足查询要求,就直接返回这个节点的信息,几个端点都不满足,就不搜了,否则就只能继续往下递归。

如果是像例题中这种求和,就只能暴力求了,但是如果是求距离最小值 / 最大值,显然可以先用边界条件看看哪个儿子有可能更优,就先搜他,再用各种爆搜用的贪心剪枝,时间复杂度就是 \(O(?)\) 了。

超长代码预警。

namespace KD {
#define mid (l + r >> 1)
    const double alpha = 0.7;
    struct TREE {
        int ls, rs, x, y, v;
        int x1, y1, x2, y2, s, sz;
        int d;
    } t[N];
    int p[N], np, root, idx;
    void push_up(int rt) {
        t[rt].s = t[t[rt].ls].s + t[t[rt].rs].s + t[rt].v;
        t[rt].sz = t[t[rt].ls].sz + t[t[rt].rs].sz + 1;
        t[rt].x1 = t[rt].x2 = t[rt].x;
        t[rt].y1 = t[rt].y2 = t[rt].y;
        if (t[rt].ls) {
            t[rt].x1 = min(t[rt].x1, t[t[rt].ls].x1);
            t[rt].y1 = min(t[rt].y1, t[t[rt].ls].y1);
            t[rt].x2 = max(t[rt].x2, t[t[rt].ls].x2);
            t[rt].y2 = max(t[rt].y2, t[t[rt].ls].y2);
        }
        if (t[rt].rs) {
            t[rt].x1 = min(t[rt].x1, t[t[rt].rs].x1);
            t[rt].y1 = min(t[rt].y1, t[t[rt].rs].y1);
            t[rt].x2 = max(t[rt].x2, t[t[rt].rs].x2);
            t[rt].y2 = max(t[rt].y2, t[t[rt].rs].y2);
        }
    }
    void getp(int rt) {
        if (!rt)
            return;
        p[++np] = rt;
        getp(t[rt].ls), getp(t[rt].rs);
    }
    bool cmpx(int x, int y) {
        return t[x].x < t[y].x;
    }
    bool cmpy(int x, int y) {
        return t[x].y < t[y].y;
    }
    void build(int &rt, int l, int r) {
        if (l > r) {
            rt = 0;
            return;
        }
        double avx = 0, avy = 0, sx = 0, sy = 0;
        for (int i = l; i <= r; i++)
            avx += t[p[i]].x, avy += t[p[i]].y;
        avx = 1.0 * avx / (r - l + 1), avy = 1.0 * avy / (r - l + 1);
        for (int i = l; i <= r; i++) {
            sx += (t[p[i]].x - avx) * (t[p[i]].x - avx);
            sy += (t[p[i]].y - avy) * (t[p[i]].y - avy);
        }
        if (sx > sy) {
            nth_element(p + l, p + mid, p + r + 1, cmpx);
            rt = p[mid];
            t[rt].d = 1;
        } else {
            nth_element(p + l, p + mid, p + r + 1, cmpy);
            rt = p[mid];
            t[rt].d = 2;
        }
        build(t[rt].ls, l, mid - 1);
        build(t[rt].rs, mid + 1, r);
        push_up(rt);
    }
    void rebuild(int &rt) {
        if (max(t[t[rt].ls].sz, t[t[rt].rs].sz) <= t[rt].sz * alpha)
            return;
        np = 0;
        getp(rt);
        build(rt, 1, np);
    }
    void update(int &rt, int v) {
        if (!rt) {
            rt = v;
            push_up(rt);
            return;
        }
        if (t[rt].d == 1)
            t[v].x <= t[rt].x ? update(t[rt].ls, v) : update(t[rt].rs, v);
        else
            t[v].y <= t[rt].y ? update(t[rt].ls, v) : update(t[rt].rs, v);
        push_up(rt);
        rebuild(rt);
    }
    int query(int rt, int x1, int y1, int x2, int y2) {
        if (!rt or t[rt].x1 > x2 or t[rt].y1 > y2 or t[rt].x2 < x1 or t[rt].y2 < y1)
            return 0;
        if (x1 <= t[rt].x1 and t[rt].x2 <= x2 and y1 <= t[rt].y1 and t[rt].y2 <= y2)
            return t[rt].s;
        int ret = 0;
        if (x1 <= t[rt].x and t[rt].x <= x2 and y1 <= t[rt].y and t[rt].y <= y2)
            ret = t[rt].v;
        return ret + query(t[rt].ls, x1, y1, x2, y2) + query(t[rt].rs, x1, y1, x2, y2);
    }
}
using namespace KD;
posted @ 2025-03-24 20:15  Garbage_fish  阅读(16)  评论(0)    收藏  举报