KD-Tree学习笔记
思想:
每个节点对应一个矩形区域。
用一条平行于坐标轴的直线把当前节点对应的区域分成两半,分别作为左右子树。
把查询点的问题转化成二叉树上的查询问题。
维持树的平衡性:
建树时,选择方差较大的那一维作为分割维,并取该维上的中位数点作为分割点。
若需要动态维护点集,当某个节点的子树过重(失去平衡)时,遍历该子树中的所有节点并重新建树。
P1429 平面最近点对(加强版)
#include <bits/stdc++.h>
using namespace std;
// Capacity of the global node array t and point array p (both are indexed from 1).
const int N = 2e5 + 10;
// Number of input points; read in main().
int n;
// Smallest squared distance between two different points found so far.
// Starts at a large sentinel and is only ever lowered by query().
double ans = 2e18;
// One k-d tree node. The node stored at t[i] belongs to the point p[i];
// index 0 doubles as the "no child" marker.
struct kd_tree {
    // Axis-aligned bounding box of the whole subtree: x in [l, r], y in [d, u].
    double l;
    double r;
    double u;
    double d;
    // Indices of the left and right children (0 when absent).
    int lc;
    int rc;
};

// Node storage, zero-initialised as a global.
kd_tree t[N];
// A 2D point with real coordinates.
struct point {
    double x;
    double y;

    // Squared Euclidean distance between this point and `p`.
    double dis2(const point &p) const {
        const double dx = x - p.x;
        const double dy = y - p.y;
        return dx * dx + dy * dy;
    }
};

// Point storage; after build() the points are permuted into tree order.
point p[N];
// Recomputes the bounding box of node x from its own point and the
// bounding boxes of its existing children.
void pushup(int x) {
    // Start from the degenerate box that contains only this node's point.
    t[x].l = p[x].x;
    t[x].r = p[x].x;
    t[x].d = p[x].y;
    t[x].u = p[x].y;

    // Grow the box by every child that is present (index 0 means "none").
    const int children[2] = {t[x].lc, t[x].rc};
    for (int k = 0; k < 2; ++k) {
        const int child = children[k];
        if (child == 0) continue;
        if (t[child].l < t[x].l) t[x].l = t[child].l;
        if (t[child].r > t[x].r) t[x].r = t[child].r;
        if (t[child].u > t[x].u) t[x].u = t[child].u;
        if (t[child].d < t[x].d) t[x].d = t[child].d;
    }
}
int build(int l, int r) {
if (l > r) return 0;
if (l == r) {
t[l].l = t[l].r = p[l].x; t[l].u = t[l].d = p[l].y;
return l;
}
//1.选择方差最大的维度
//2.选择中位数进行分割
int mid = (l + r) >> 1;
double vx = 0, vy = 0, sx = 0, sy = 0;
for (int i = l; i <= r; i++) vx += p[i].x, vy += p[i].y;
vx /= 1.0 * (r - l + 1); vy /= 1.0 * (r - l + 1);
for (int i = l; i <= r; i++)
sx += (p[i].x - vx) * (p[i].x - vx), sy += (p[i].y - vy) * (p[i].y - vy);
if (sx >= sy) nth_element(p + l, p + mid, p + r + 1, [](point a, point b) {return a.x < b.x;} );
else nth_element(p + l, p + mid, p + r + 1, [](point a, point b) {return a.y < b.y;} );
t[mid].lc = build(l, mid - 1); t[mid].rc = build(mid + 1, r);
pushup(mid);
return mid;
}
void query(int l, int r, int x) {
if (l > r) return;
int mid = (l + r) >> 1;
if (mid != x) ans = min(ans, p[mid].dis2(p[x]));
if (l == r) return;
//求点到矩形的最短距离平方
auto f = [](point &p, kd_tree &q) {
double l = q.l, r = q.r, u = q.u, d = q.d, x = p.x, y = p.y;
double res = 0;
if (l > x) res += (l - x) * (l - x);
if (r < x) res += (x - r) * (x - r);
if (d > y) res += (d - y) * (d - y);
if (u < y) res += (y - u) * (y - u);
return res;
};
double disl = f(p[x], t[t[mid].lc]), disr = f(p[x], t[t[mid].rc]);
//启发式查询
if (disl < disr) {
if (disl < ans) query(l, mid - 1, x);
if (disr < ans) query(mid + 1, r, x);
} else {
if (disr < ans) query(mid + 1, r, x);
if (disl < ans) query(l, mid - 1, x);
}
}
// Reads n points, builds the k-d tree, runs a nearest-neighbour query from
// every point, and prints the smallest distance with four decimals.
int main() {
    scanf("%d", &n);

    int idx = 1;
    while (idx <= n) {
        scanf("%lf%lf", &p[idx].x, &p[idx].y);
        ++idx;
    }

    build(1, n);

    // `ans` accumulates the minimum squared distance over all queries.
    for (int src = 1; src <= n; ++src) {
        query(1, n, src);
    }

    const double closest = sqrt(ans);
    printf("%.4lf\n", closest);
    return 0;
}
P4475 巧克力王国
//
// Created by blackbird on 2023/3/16.
//
#include <bits/stdc++.h>
#define int long long
using namespace std;
// Capacity of the global node array t and point array p (both are indexed from 1).
// Note: `int` is redefined to `long long` by the macro above.
const int N = 2e5 + 10;
// n: number of points, m: number of queries, both read in main().
// a, b, c: coefficients of the current query; a point matches when a*x + b*y < c.
int n, m, a, b, c;
// Not referenced by any code visible in this program.
double ans = 2e18;
// One k-d tree node. The node stored at t[i] belongs to the point p[i];
// index 0 doubles as the "no child" marker.
struct kd_tree {
    // Axis-aligned bounding box of the whole subtree: x in [l, r], y in [d, u].
    int l;
    int r;
    int u;
    int d;
    // Indices of the left and right children (0 when absent).
    int lc;
    int rc;
    // Sum of `val` over every point in the subtree.
    int sum;
};

// Node storage, zero-initialised as a global.
kd_tree t[N];
// A 2D point with integer coordinates and an attached value.
struct point {
    int x;
    int y;
    int val;

    // Squared Euclidean distance between this point and `p`, computed in
    // integer arithmetic and converted to double on return.
    double dis2(const point &p) const {
        const int dx = x - p.x;
        const int dy = y - p.y;
        return dx * dx + dy * dy;
    }
};

// Point storage; after build() the points are permuted into tree order.
point p[N];
// Recomputes the bounding box and value sum of node x from its own point
// and the aggregates of its existing children.
void pushup(int x) {
    // Start from this node's own point only.
    t[x].l = p[x].x;
    t[x].r = p[x].x;
    t[x].d = p[x].y;
    t[x].u = p[x].y;
    t[x].sum = p[x].val;

    // Merge in every child that is present (index 0 means "none").
    const int children[2] = {t[x].lc, t[x].rc};
    for (int k = 0; k < 2; ++k) {
        const int child = children[k];
        if (child == 0) continue;
        if (t[child].l < t[x].l) t[x].l = t[child].l;
        if (t[child].r > t[x].r) t[x].r = t[child].r;
        if (t[child].u > t[x].u) t[x].u = t[child].u;
        if (t[child].d < t[x].d) t[x].d = t[child].d;
        t[x].sum += t[child].sum;
    }
}
int build(int l, int r) {
if (l > r) return 0;
if (l == r) {
t[l].l = t[l].r = p[l].x;
t[l].u = t[l].d = p[l].y;
t[l].sum = p[l].val;
return l;
}
//1.选择方差最大的维度
//2.选择中位数进行分割
int mid = (l + r) >> 1;
double vx = 0, vy = 0, sx = 0, sy = 0;
for (int i = l; i <= r; i++) vx += p[i].x, vy += p[i].y;
vx /= 1.0 * (r - l + 1); vy /= 1.0 * (r - l + 1);
for (int i = l; i <= r; i++)
sx += (p[i].x - vx) * (p[i].x - vx), sy += (p[i].y - vy) * (p[i].y - vy);
if (sx >= sy) nth_element(p + l, p + mid, p + r + 1, [](point a, point b) {return a.x < b.x;} );
else nth_element(p + l, p + mid, p + r + 1, [](point a, point b) {return a.y < b.y;} );
t[mid].lc = build(l, mid - 1); t[mid].rc = build(mid + 1, r);
pushup(mid);
return mid;
}
// Returns the sum of `val` over all points in the subtree rooted at node u
// whose coordinates satisfy a*x + b*y < c for the current global a, b, c.
int query(int u) {
    const auto inside = [](int x, int y) { return a * x + b * y < c; };

    // Count how many corners of the subtree's bounding box satisfy the
    // inequality. A linear function on a rectangle takes its extreme values at
    // the corners, so 4 hits mean the whole box matches and 0 hits mean none of it does.
    int corners = 0;
    corners += inside(t[u].l, t[u].d);
    corners += inside(t[u].l, t[u].u);
    corners += inside(t[u].r, t[u].d);
    corners += inside(t[u].r, t[u].u);
    if (corners == 4) return t[u].sum;
    if (corners == 0) return 0;

    // The box is cut by the line: test this node's own point, then recurse.
    int total = 0;
    if (inside(p[u].x, p[u].y)) total += p[u].val;
    if (t[u].lc) total += query(t[u].lc);
    if (t[u].rc) total += query(t[u].rc);
    return total;
}
// Reads n weighted points and m queries (a, b, c); for each query prints the
// sum of the values of all points with a*x + b*y < c, one result per line.
signed main() {
    // This program performs all I/O through cin/cout, so the C stdio
    // synchronisation and the cin->cout tie can be dropped to avoid a flush
    // before every read.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n >> m;
    for (int i = 1; i <= n; i++)
        std::cin >> p[i].x >> p[i].y >> p[i].val;

    int rt = build(1, n);
    for (int i = 1; i <= m; i++) {
        std::cin >> a >> b >> c;
        std::cout << query(rt) << '\n';
    }
    return 0;
}

浙公网安备 33010602011771号