📚【模板】K-D Tree

\(\text{Loj #2043. 「CQOI2016」K 远点对}\)

#include <stdio.h>
#include <queue>
#include <bits/stl_algo.h>
using namespace std;
// Capacity of the point array and the tree array.
const int N = 131026;
// n: number of points; k: requested rank (main doubles it before use).
long long n, k;
// Min-heap (greater<>) holding the largest squared distances found so far;
// q.top() is the smallest value currently kept.
priority_queue<long long,vector<long long>,greater<long long> > q;

// One input point; coordinates are read with %lld, so both are long long.
struct NODE {
	long long x;
	long long y;
};
NODE node[N];
// Axis orderings handed to nth_element: comp1 orders points by x, comp2 by y.
bool comp1(NODE lhs,NODE rhs) {
	return rhs.x > lhs.x;
}
bool comp2(NODE lhs,NODE rhs) {
	return rhs.y > lhs.y;
}

// Per-node K-D tree data. The tree node for point node[i] has id i:
// construct() returns the median array index as the node id.
struct TREE {
	long long lid, rid;		// child node ids; 0 means "no child"
	long long L, R, D, U;	// bounding box of the subtree: x in [L,R], y in [D,U]
} tree[N];
// Field accessors, e.g. L(id) is tree[id].L.
#define lid(id) tree[id].lid
#define rid(id) tree[id].rid
#define L(id) tree[id].L
#define R(id) tree[id].R
#define U(id) tree[id].U
#define D(id) tree[id].D

// Square of a value, computed in the argument's own type.
template<typename T>
T sq (T value) {
	const T squared = value*value;
	return squared;
}
// Upper bound on the squared Euclidean distance from point _a to any point in
// the bounding box of tree node _b: on each axis take the farther box edge.
long long dist(int _a,int _b) {
	const long long qx = node[_a].x;
	const long long qy = node[_a].y;
	const long long farX = max(sq(qx-L(_b)),sq(qx-R(_b)));
	const long long farY = max(sq(qy-D(_b)),sq(qy-U(_b)));
	return farX+farY;
}

// Recompute node id's bounding box from its own point and then widen it by
// the boxes of its non-empty children (child id 0 means "absent").
void update(int id) {
	L(id) = R(id) = node[id].x;
	U(id) = D(id) = node[id].y;
	const long long kids[2] = {lid(id),rid(id)};
	for(int k = 0;k < 2;++k) {
		const long long c = kids[k];
		if(!c)
			continue;
		L(id) = min(L(id),L(c));
		R(id) = max(R(id),R(c));
		D(id) = min(D(id),D(c));
		U(id) = max(U(id),U(c));
	}
}

// Build a K-D tree over node[lo..hi] in place and return its root id, which is
// the median array index; return 0 for an empty range. The split axis is
// whichever coordinate has the larger spread (sum of squared deviations).
int construct(int lo,int hi) {
	if(lo > hi)
		return 0;
	const int mid = (lo+hi)>>1;
	const int cnt = hi-lo+1;
	double meanX = 0, meanY = 0;
	for(int i = lo;i <= hi;++i) {
		meanX += node[i].x;
		meanY += node[i].y;
	}
	meanX /= cnt;
	meanY /= cnt;
	double varX = 0, varY = 0;
	for(int i = lo;i <= hi;++i) {
		const double dx = meanX-node[i].x;
		const double dy = meanY-node[i].y;
		varX += dx*dx;
		varY += dy*dy;
	}
	nth_element(node+lo,node+mid,node+hi+1,varX > varY ? comp1 : comp2);
	lid(mid) = construct(lo,mid-1);
	rid(mid) = construct(mid+1,hi);
	update(mid);
	return mid;
}

void query(int id,int L,int R) {
	if(L > R) return;
	int mid = (L+R)>>1;
	long long dis = sq(node[mid].x-node[id].x)+sq(node[mid].y-node[id].y);
	if(dis > q.top()) {
		q.pop();
		q.push(dis);
	}
	long long dis1 = dist(id,lid(mid)), dis2 = dist(id,rid(mid));
	if(dis1 > q.top()&&dis2 > q.top()) {
		if(dis1 > dis2) {
			query(id,L,mid-1);
			if(dis2 > q.top()) 
				query(id,mid+1,R);
		} else {
			query(id,mid+1,R);
			if(dis1 > q.top()) 
				query(id,L,mid-1);
		}
	} else {
		if(dis1 > q.top()) 
			query(id,L,mid-1);
		if(dis2 > q.top()) 
			query(id,mid+1,R);
	}
}

// Read n and k, then the n points. The heap is seeded with 2k zeros and
// every point queries the whole tree; the heap minimum is printed at the end.
// Presumably k is doubled because each unordered pair can be reached from
// both of its endpoints.
signed main() {
	scanf("%lld %lld",&n,&k);
	k <<= 1;
	for(int filled = 0;filled < k;++filled)
		q.push(0);
	for(int i = 1;i <= n;++i)
		scanf("%lld %lld",&node[i].x,&node[i].y);
	construct(1,n);
	for(int src = 1;src <= n;++src)
		query(src,1,n);
	printf("%lld\n",q.top());
}


\(\text{luogu P4169 [Violet]天使玩偶/SJY摆棋子}\)

啊,是带插入的K-D Tree。

  1. 我们的K-D Tree
// A 2-D point; `temp` holds the point read for the current operation.
struct NODE {
	int x, y;
} node[N], temp;
// A K-D tree node. It stores its own point (pt) because rebuilds rewrite
// and reorder node[].
struct TREE {
	int lid, rid;		// child slots; 0 means "no child"
	int U, D, L, R;		// bounding box of the subtree: x in [L,R], y in [D,U]
	int size;			// number of points in this subtree
	NODE pt;			// the point stored at this node
} tree[N];
int leaf_tot, root;		// highest slot handed out so far; current root slot
double alpha = 0.75;	// rebuild when a child holds more than alpha * size points
// Field accessors, e.g. L(id) is tree[id].L.
#define lid(id) tree[id].lid
#define rid(id) tree[id].rid
#define L(id) tree[id].L
#define R(id) tree[id].R
#define U(id) tree[id].U
#define D(id) tree[id].D
#define pn(id) tree[id].pt
#define px(id) tree[id].pt.x
#define py(id) tree[id].pt.y
// pq(id,ch): the node's coordinate on the split axis (x if ch is nonzero, else y).
#define pq(id,ch) (ch ? tree[id].pt.x : tree[id].pt.y)
#define size(id) tree[id].size
int stk[N], top;		// stack of recycled tree slots
  • 之前我们不用在每个TREE节点里存NODE,是因为点的顺序不会被打乱(树节点编号就是它在node数组里的下标);可现在要支持插入和重建,node数组的顺序会变,所以要把点存进节点里。
  1. 如何建树?
// Build a K-D tree over node[lo..hi] in place and return its root slot, or 0
// for an empty range. A nonzero choice splits on x and zero splits on y; the
// axis alternates per level. Slots come from newleaf(), so freed slots are reused.
int construct(int lo,int hi,int choice) {
	if(hi < lo)
		return 0;
	const int slot = newleaf();
	const int mid = (lo+hi)>>1;
	out_choice = choice; // the global operator< reads this to pick the axis
	std :: nth_element(node+lo,node+mid,node+hi+1);
	pn(slot) = node[mid];
	lid(slot) = construct(lo,mid-1,choice^1);
	rid(slot) = construct(mid+1,hi,choice^1);
	update(slot);
	return slot;
}
  • 正常小建一下就好了。
  1. 上传、建新节点、比较
// Refresh node id's bounding box (x in [L,R], y in [D,U]) and its subtree size
// from its own point and its non-empty children (slot 0 means "absent").
void update(int id) {
	U(id) = D(id) = py(id);
	L(id) = R(id) = px(id);
	const int kids[2] = {lid(id),rid(id)};
	for(int k = 0;k < 2;++k) {
		const int c = kids[k];
		if(c == 0)
			continue;
		L(id) = std :: min(L(id),L(c));
		R(id) = std :: max(R(id),R(c));
		D(id) = std :: min(D(id),D(c));
		U(id) = std :: max(U(id),U(c));
	}
	size(id) = size(lid(id))+size(rid(id))+1;
}
// Stack of tree slots freed by refold. newleaf prefers one of these and only
// allocates a fresh slot past leaf_tot when the stack is empty.
int stk[N], top;
int newleaf() {
	return top ? stk[top--] : ++leaf_tot;
}
  • 加了回收栈。
// Split axis for the global NODE ordering: nonzero compares x, zero compares y.
// construct sets it right before each nth_element call.
int out_choice;
bool operator < (const NODE &_a,const NODE &_b) {
	if(out_choice)
		return _a.x < _b.x;
	return _a.y < _b.y;
}
  • 毕竟得确定一下按哪一维排序。
  1. 插入
// Insert point `in` into the subtree whose root slot is `id` (0 = empty).
// The split axis is x if choice is nonzero, otherwise y. A point whose key is
// greater than the node's goes right; every other point goes left. On the way
// back up, each node is refreshed and passed to check() for rebalancing.
void insert(const NODE &in,int &id,int choice) {
	if(id == 0) {
		id = newleaf();
		pn(id) = in;
		lid(id) = 0;
		rid(id) = 0;
		update(id);
		return;
	}
	const int key = choice ? in.x : in.y;
	int &branch = (pq(id,choice) < key) ? rid(id) : lid(id);
	insert(in,branch,choice^1);
	update(id);
	check(id,choice);
}
  • 很显然,我们找到相应的位置插入就好了。但是很明显这里会有一个问题:不断地插入节点会导致树的失衡(\(\text{你}\log n\text{没了}\))。所以,想想我们在替罪羊树时的处理吧:检查是否失衡,失衡了就推平重建一下。
  1. 检查失衡、推平重建
// Build a K-D tree over node[lo..hi] in place and return its root slot, or 0
// for an empty range. A nonzero choice splits on x and zero splits on y; the
// axis alternates per level. Slots come from newleaf(), so freed slots are reused.
int construct(int lo,int hi,int choice) {
	if(hi < lo)
		return 0;
	const int slot = newleaf();
	const int mid = (lo+hi)>>1;
	out_choice = choice; // the global operator< reads this to pick the axis
	std :: nth_element(node+lo,node+mid,node+hi+1);
	pn(slot) = node[mid];
	lid(slot) = construct(lo,mid-1,choice^1);
	rid(slot) = construct(mid+1,hi,choice^1);
	update(slot);
	return slot;
}
// Flatten the subtree rooted at id into node[base+1 .. base+size(id)] in order
// (left subtree, id, right subtree). Each visited slot is pushed onto stk so
// the following construct() can reuse it.
void refold(int id,int base) {
	if(lid(id) != 0)
		refold(lid(id),base);
	const int pos = base+size(lid(id))+1;
	node[pos] = pn(id);
	stk[++top] = id;
	if(rid(id) != 0)
		refold(rid(id),pos);
}
// Rebalance the subtree at id. If either child holds more than alpha * size(id)
// points, flatten the subtree into node[1..size(id)] and rebuild it on split
// axis `choice`, writing the new root back through the reference.
void check(int &id,int choice) {
	const double limit = alpha*size(id);
	if(size(lid(id)) <= limit&&size(rid(id)) <= limit)
		return;
	const int total = size(id);
	refold(id,0);
	id = construct(1,total,choice);
}
  • 重建与原来的建树是同一个函数。
  1. 查询
// Manhattan distance from point `in` to node id's bounding box; 0 if the
// point lies inside the box.
int disps(const NODE &in,int id) {
	const int gapX = std :: max(0,in.x-R(id))+std :: max(0,L(id)-in.x);
	const int gapY = std :: max(0,in.y-U(id))+std :: max(0,D(id)-in.y);
	return gapX+gapY;
}
// Manhattan (L1) distance between two points.
// Uses std::abs from <cstdlib> explicitly: the unqualified abs was only visible
// through a transitive include of libstdc++'s internal <bits/stl_algo.h>.
int dispp(const NODE &_a,const NODE &_b) {
	return std :: abs(_a.x-_b.x)+std :: abs(_a.y-_b.y);
}
// Smallest Manhattan distance found so far by the running query (reset by main).
int query_ans;
// Nearest-point search from `in` over the subtree at slot id (callers pass a
// nonzero slot). Fold this node's point into query_ans, then visit the child
// with the closer bounding box first. A child is skipped if its box distance
// cannot beat query_ans. A missing child gets the sentinel distance 998244353.
void query(const NODE &in,int id) {
	query_ans = std :: min(query_ans,dispp(in,pn(id)));
	const int kMissing = 998244353;
	const int gapL = lid(id) ? disps(in,lid(id)) : kMissing;
	const int gapR = rid(id) ? disps(in,rid(id)) : kMissing;
	const bool leftFirst = gapL < gapR;
	const int firstChild = leftFirst ? lid(id) : rid(id);
	const int firstGap = leftFirst ? gapL : gapR;
	const int secondChild = leftFirst ? rid(id) : lid(id);
	const int secondGap = leftFirst ? gapR : gapL;
	if(firstGap < query_ans)
		query(in,firstChild);
	if(secondGap < query_ans)
		query(in,secondChild);
}
  • 还是原来的查询函数,尽量少搜节点。
code here
#include <stdio.h>
#include <cstdlib>
#include <bits/stl_algobase.h>
#include <bits/stl_algo.h>
// Capacity of the point array, the tree-slot array and the recycling stack.
const int N = 1048576;
// A 2-D point; `temp` holds the point read for the current operation.
struct NODE {
	int x, y;
} node[N], temp;
// A K-D tree node. It stores its own point (pt) because rebuilds rewrite
// and reorder node[].
struct TREE {
	int lid, rid;		// child slots; 0 means "no child"
	int U, D, L, R;		// bounding box of the subtree: x in [L,R], y in [D,U]
	int size;			// number of points in this subtree
	NODE pt;			// the point stored at this node
} tree[N];
int leaf_tot, root;		// highest slot handed out so far; current root slot
double alpha = 0.75;	// rebuild when a child holds more than alpha * size points
// Field accessors, e.g. L(id) is tree[id].L.
#define lid(id) tree[id].lid
#define rid(id) tree[id].rid
#define L(id) tree[id].L
#define R(id) tree[id].R
#define U(id) tree[id].U
#define D(id) tree[id].D
#define pn(id) tree[id].pt
#define px(id) tree[id].pt.x
#define py(id) tree[id].pt.y
// pq(id,ch): the node's coordinate on the split axis (x if ch is nonzero, else y).
#define pq(id,ch) (ch ? tree[id].pt.x : tree[id].pt.y)
#define size(id) tree[id].size
int stk[N], top;		// stack of recycled tree slots
// Hand out a tree slot. Reuse one freed by refold if any is on the stack;
// otherwise allocate a fresh slot past leaf_tot.
int newleaf() {
	return top ? stk[top--] : ++leaf_tot;
}
// Split axis for the global NODE ordering: nonzero compares x, zero compares y.
// construct sets it right before each nth_element call.
int out_choice;
bool operator < (const NODE &_a,const NODE &_b) {
	if(out_choice)
		return _a.x < _b.x;
	return _a.y < _b.y;
}
// Refresh node id's bounding box (x in [L,R], y in [D,U]) and its subtree size
// from its own point and its non-empty children (slot 0 means "absent").
void update(int id) {
	U(id) = D(id) = py(id);
	L(id) = R(id) = px(id);
	const int kids[2] = {lid(id),rid(id)};
	for(int k = 0;k < 2;++k) {
		const int c = kids[k];
		if(c == 0)
			continue;
		L(id) = std :: min(L(id),L(c));
		R(id) = std :: max(R(id),R(c));
		D(id) = std :: min(D(id),D(c));
		U(id) = std :: max(U(id),U(c));
	}
	size(id) = size(lid(id))+size(rid(id))+1;
}
// Build a K-D tree over node[lo..hi] in place and return its root slot, or 0
// for an empty range. A nonzero choice splits on x and zero splits on y; the
// axis alternates per level. Slots come from newleaf(), so freed slots are reused.
int construct(int lo,int hi,int choice) {
	if(hi < lo)
		return 0;
	const int slot = newleaf();
	const int mid = (lo+hi)>>1;
	out_choice = choice; // the global operator< reads this to pick the axis
	std :: nth_element(node+lo,node+mid,node+hi+1);
	pn(slot) = node[mid];
	lid(slot) = construct(lo,mid-1,choice^1);
	rid(slot) = construct(mid+1,hi,choice^1);
	update(slot);
	return slot;
}
// Flatten the subtree rooted at id into node[base+1 .. base+size(id)] in order
// (left subtree, id, right subtree). Each visited slot is pushed onto stk so
// the following construct() can reuse it.
void refold(int id,int base) {
	if(lid(id) != 0)
		refold(lid(id),base);
	const int pos = base+size(lid(id))+1;
	node[pos] = pn(id);
	stk[++top] = id;
	if(rid(id) != 0)
		refold(rid(id),pos);
}
// Rebalance the subtree at id. If either child holds more than alpha * size(id)
// points, flatten the subtree into node[1..size(id)] and rebuild it on split
// axis `choice`, writing the new root back through the reference.
void check(int &id,int choice) {
	const double limit = alpha*size(id);
	if(size(lid(id)) <= limit&&size(rid(id)) <= limit)
		return;
	const int total = size(id);
	refold(id,0);
	id = construct(1,total,choice);
}
// Insert point `in` into the subtree whose root slot is `id` (0 = empty).
// The split axis is x if choice is nonzero, otherwise y. A point whose key is
// greater than the node's goes right; every other point goes left. On the way
// back up, each node is refreshed and passed to check() for rebalancing.
void insert(const NODE &in,int &id,int choice) {
	if(id == 0) {
		id = newleaf();
		pn(id) = in;
		lid(id) = 0;
		rid(id) = 0;
		update(id);
		return;
	}
	const int key = choice ? in.x : in.y;
	int &branch = (pq(id,choice) < key) ? rid(id) : lid(id);
	insert(in,branch,choice^1);
	update(id);
	check(id,choice);
}
// Manhattan distance from point `in` to node id's bounding box; 0 if the
// point lies inside the box.
int disps(const NODE &in,int id) {
	const int gapX = std :: max(0,in.x-R(id))+std :: max(0,L(id)-in.x);
	const int gapY = std :: max(0,in.y-U(id))+std :: max(0,D(id)-in.y);
	return gapX+gapY;
}
// Manhattan (L1) distance between two points.
// Uses std::abs from <cstdlib> explicitly: the unqualified abs was only visible
// through a transitive include of libstdc++'s internal <bits/stl_algo.h>.
int dispp(const NODE &_a,const NODE &_b) {
	return std :: abs(_a.x-_b.x)+std :: abs(_a.y-_b.y);
}
// Smallest Manhattan distance found so far by the running query (reset by main).
int query_ans;
// Nearest-point search from `in` over the subtree at slot id (callers pass a
// nonzero slot). Fold this node's point into query_ans, then visit the child
// with the closer bounding box first. A child is skipped if its box distance
// cannot beat query_ans. A missing child gets the sentinel distance 998244353.
void query(const NODE &in,int id) {
	query_ans = std :: min(query_ans,dispp(in,pn(id)));
	const int kMissing = 998244353;
	const int gapL = lid(id) ? disps(in,lid(id)) : kMissing;
	const int gapR = rid(id) ? disps(in,rid(id)) : kMissing;
	const bool leftFirst = gapL < gapR;
	const int firstChild = leftFirst ? lid(id) : rid(id);
	const int firstGap = leftFirst ? gapL : gapR;
	const int secondChild = leftFirst ? rid(id) : lid(id);
	const int secondGap = leftFirst ? gapR : gapL;
	if(firstGap < query_ans)
		query(in,firstChild);
	if(secondGap < query_ans)
		query(in,secondChild);
}
int n, m, opt;
signed main() {
	scanf("%d %d",&n,&m);
	for(int i = 1;i <= n;++i) 
		scanf("%d %d",&node[i].x,&node[i].y);
	root = construct(1,n,0);
	for(int i = 1;i <= m;++i) {
		scanf("%d %d %d",&opt,&temp.x,&temp.y);
		if(opt == 1) 
			insert(temp,root,0);
		else {
			query_ans = 998244353;
			query(temp,root);
			printf("%d\n",query_ans);
		}
	}
}
posted @ 2022-07-27 21:07  bikuhiku  阅读(11)  评论(0编辑  收藏  举报