Loading

kd-tree

基础讲解

\(kd-tree\) 不是很常考。但是NOI考过,省选也考过,也可以干一点暴力骗分的事,有时候有奇效。

一般带修,可插入,允许强制在线(只是麻烦一点),然后查询就稀奇古怪了,比如给个矩形,问矩形里有几个点(也可以树套树),或者给每个点一个权值,问矩形里点权和,再或者把矩形改成某种奇怪的解析式,比如圆、椭圆什么的。也可以拿来优化暴力。

关于 \(kd-tree\) ,我觉得它和线段树很像,都可以打标记,然后还可以扩展到k维,这就是为什么叫\(kd-tree\)

\(kd-tree\) 通常是采用循环分割的方法建树的。什么叫循环分割?每次按照某一维为关键字排序,取出中位数,然后以这个节点为根,把[l,mid-1],[mid+1,r]作为左右子树,进行分割。

int cmp_D;
bool cmp(const node &a,const node &b){return a.d[cmp_d]<b.d[cmp_d];}
int build(int l,int r,int D){
	int mid=l+r>>1;cmp_d=D;
	nth_element(t+l,t+mid,t+r+1,cmp);
	t[mid].mx[0]=t[mid].mn[0]=t[mid].d[0];
	t[mid].mx[1]=t[mid].mn[1]=t[mid].d[1];
	if(l!=mid)t[mid].ls=build(l,mid-1,D%k+1);
	if(r!=mid)t[mid].rs=build(mid+1,r,D%k+1);
	up(mid);return mid;
}

这样可以把空间分得均匀一些。

关于那个 \(nth\_lement\) 据说是均摊 \(O(n)\) 把中位数放到mid的位置,其余数据不保证有序(和sort很像)

然后就是 \(kd-tree\) 的优化了(它也是因为这个才那么快)包围盒

包围盒,就是这个节点底下的子树的最小、最大横坐标、纵坐标。可以在pushup的时候记录。还是看代码。

void pushup(int p) {
	tr[p].sz=1,tr[p].sum=tr[p].tp.val;
	tr[p].mn[0]=tr[p].mx[0]=tr[p].tp.d[0];
	tr[p].mn[1]=tr[p].mx[1]=tr[p].tp.d[1];
	if(tr[p].ls) {
		tr[p].mn[0]=min(tr[p].mn[0],tr[tr[p].ls].mn[0]);
		tr[p].mn[1]=min(tr[p].mn[1],tr[tr[p].ls].mn[1]);
		tr[p].mx[0]=max(tr[p].mx[0],tr[tr[p].ls].mx[0]);
		tr[p].mx[1]=max(tr[p].mx[1],tr[tr[p].ls].mx[1]);
		tr[p].sz+=tr[tr[p].ls].sz,tr[p].sum+=tr[tr[p].ls].sum;
	}
	if(tr[p].rs) {
		tr[p].mn[0]=min(tr[p].mn[0],tr[tr[p].rs].mn[0]);
		tr[p].mn[1]=min(tr[p].mn[1],tr[tr[p].rs].mn[1]);
		tr[p].mx[0]=max(tr[p].mx[0],tr[tr[p].rs].mx[0]);
		tr[p].mx[1]=max(tr[p].mx[1],tr[tr[p].rs].mx[1]);
		tr[p].sz+=tr[tr[p].rs].sz,tr[p].sum+=tr[tr[p].rs].sum;
	}
}

这种都可以灵活记录,主要看题目,但是mn,mx一般都会记下来。

然后说说为什么就快起来了。

有了包围盒,在查询的时候,比如上面要查询每个矩形内有几个点,就可以判断这个点的包围盒是否在矩形外,如果是,那么直接舍掉,因为不可能对答案产生贡献。否则,就像线段树一样进去分左右子树查询。

这样很可能一次少一颗子树,这使得kdt在随机数据下几乎是 \(O(n\log n)\) 的。

实际复杂度应该是 \(O(n^{\frac{k+1}{k}})\)(查询),插入和删除都是 \(O(\log n)\) 的。抱歉,这里之前写错了,误导了一些人过了几乎一年才来修

例1

描述

有一列元素,每一个元素有三个属性:标号、标识符、数值。这些元素按照标号从1n排列,标识符也是1n的一个排列,初始时数值为0。当然我们可以把每个元素看成一个多维数字,那么这列元素就是一个数列。

现在请你维护这个数列,使其能支持以下两种操作:

1.将标号为l~r的所有元素的数值先乘上x,再加上y;

2.将标识符为l~r的所有元素的数值先乘上x,再加上y。

当然你还得回答某些询问:

1.标号为l~r的所有元素的数值的和;

2.标识符为l~r的所有元素的数值的和。

输入

第一行有两个正整数n、m,分别表示数列长度和操作与询问个数的总和。

第二行有n个正整数,表示每个元素的标识符,保证这n个数是1~n的一个排列。

接下来m行,每行的第一个数字为op。若op为0,则表示要进行第一个操作,接下去四个数字表示l,r,x,y;若op为1,则表示要进行第二个操作,接下去四个数字表示l,r,x,y;若op为2,则表示要回答第一个询问,接下去两个数字表示l,r;若op为3,则表示要回答第二个询问,接下去两个数字表示l,r。

输出

包含若干行,每行表示一个询问的答案。由于答案可能很大,只要请你输出答案对536870912取模后的值即可。

可以把 \((i,{p_i})\) 看作2维平面上的点,标识符是横着的矩形,区间是竖着的矩形,然后查询区间和,打个懒惰标记。cnt记录区间内有几个点,val是当前点的权值,sum是当前的区间和,这样懒标就可以下传了。其实这题和线段树很像

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=50005;
const int mod=536870912;
struct node{
	int d[2],mn[2],mx[2],ls,rs;
	int cnt,sum,val;
	int pl,ti;
}t[N];
int n,m,ans,rt,tpl,tti;
int op,l,r,x,y,dir,cmp_d;
void up(int u){
	t[u].cnt=1;
	if(t[u].ls){
		t[u].cnt+=t[t[u].ls].cnt;
		t[u].mx[0]=t[u].mx[0]>t[t[u].ls].mx[0]?t[u].mx[0]:t[t[u].ls].mx[0];
		t[u].mx[1]=t[u].mx[1]>t[t[u].ls].mx[1]?t[u].mx[1]:t[t[u].ls].mx[1];
		t[u].mn[0]=t[u].mn[0]<t[t[u].ls].mn[0]?t[u].mn[0]:t[t[u].ls].mn[0];
		t[u].mn[1]=t[u].mn[1]<t[t[u].ls].mn[1]?t[u].mn[1]:t[t[u].ls].mn[1];
	}
	if(t[u].rs){
		t[u].cnt+=t[t[u].rs].cnt;
		t[u].mx[0]=t[u].mx[0]>t[t[u].rs].mx[0]?t[u].mx[0]:t[t[u].rs].mx[0];
		t[u].mx[1]=t[u].mx[1]>t[t[u].rs].mx[1]?t[u].mx[1]:t[t[u].rs].mx[1];
		t[u].mn[0]=t[u].mn[0]<t[t[u].rs].mn[0]?t[u].mn[0]:t[t[u].rs].mn[0];
		t[u].mn[1]=t[u].mn[1]<t[t[u].rs].mn[1]?t[u].mn[1]:t[t[u].rs].mn[1];
	}
}
bool cmp(const node &a,const node &b){return a.d[cmp_d]<b.d[cmp_d];}
int build(int l,int r,int D){
	int mid=l+r>>1;cmp_d=D;
	nth_element(t+l,t+mid,t+r+1,cmp);
	t[mid].mx[0]=t[mid].mn[0]=t[mid].d[0];
	t[mid].mx[1]=t[mid].mn[1]=t[mid].d[1];
	if(l!=mid)t[mid].ls=build(l,mid-1,D^1);
	if(r!=mid)t[mid].rs=build(mid+1,r,D^1);
	up(mid);return mid;
}
void f(int x,int ti,int pl){
	t[x].val=(t[x].val*ti+pl)%mod;
	t[x].sum=(t[x].sum*ti+pl*t[x].cnt)%mod;
	t[x].ti=(t[x].ti*ti)%mod;
	t[x].pl=(t[x].pl*ti+pl)%mod;
}
void down(int u){
	if(t[u].pl==1&&t[u].ti==0)return;
	if(t[u].ls)f(t[u].ls,t[u].ti,t[u].pl);
	if(t[u].rs)f(t[u].rs,t[u].ti,t[u].pl);
	t[u].ti=1;t[u].pl=0;
}
void update(int u){
	if(t[u].mx[dir]<l||t[u].mn[dir]>r)return;
	if(l<=t[u].mn[dir]&&t[u].mx[dir]<=r){f(u,tti,tpl);return;}
	down(u);
	if(t[u].d[dir]>=l&&t[u].d[dir]<=r)t[u].val=(t[u].val*x+y)%mod;
	if(t[u].ls)update(t[u].ls);
	if(t[u].rs)update(t[u].rs);
	t[u].sum=(t[u].val+t[t[u].ls].sum+t[t[u].rs].sum)%mod;
}
void ask(int u)
{
	if(t[u].mx[dir]<l||t[u].mn[dir]>r)return;
	if(l<=t[u].mn[dir]&&t[u].mx[dir]<=r){ans=(ans+t[u].sum)%mod;return;}
	down(u);
	if(l<=t[u].d[dir]&&t[u].d[dir]<=r)ans=(ans+t[u].val)%mod;
	if(t[u].ls)ask(t[u].ls);
	if(t[u].rs)ask(t[u].rs);
 } 
signed main()
{
	scanf("%lld%lld",&n,&m);
	for(int i=1;i<=n;++i)scanf("%lld",&t[i].d[1]),t[i].d[0]=i;
	rt=build(1,n,0);
	while(m--){
		scanf("%lld",&op);
		if(op==0){
			scanf("%lld%lld%lld%lld",&l,&r,&x,&y);
			dir=0;tpl=y%mod;tti=x%mod;update(rt);
		}
		if(op==1){
			scanf("%lld%lld%lld%lld",&l,&r,&x,&y);
			dir=1;tpl=y%mod;tti=x%mod;update(rt);
		}
		if(op==2){
			scanf("%lld%lld",&l,&r);
			dir=0;ans=0;ask(rt);
			printf("%lld\n",ans);
		}
		if(op==3){
			scanf("%lld%lld",&l,&r);
			dir=1;ans=0;ask(rt);
			printf("%lld\n",ans);
		}
	} 
	return 0;
 } 

接着就是上面说到的强制在线的问题了。插入当然可以暴力把点塞进去,但是多了就会比较麻烦:因为数据可以不断地卡 \(kd-tree\) ,卡成一条链,时间复杂度就爆炸了。

对于那些不强制在线的题,可以直接把所有点先插进去,对于目前在树上和不在树上的点打标记,插入后更新标记就好了。

然而强制在线的话……

就要用到替罪羊树的思想拍平重构。\(\color{black}{\texttt{h}}\color{red}{\texttt{ehezhou}}:\) 所有二叉树形结构都可以替罪羊重构

例2

P4148 简单题

这题主要就是要重构烦,其他没什么难的。

#include<bits/stdc++.h>
using namespace std;
const int N=200005;
int n,lastans,root;
struct point {
	int d[2],val;
} a[N];
struct node {
	int sum,ls,rs,mx[2],mn[2],sz;
	point tp;
} tr[N];
int tot,top,rub[N];
int min(const int &a,const int &b) {
	return a<b?a:b;
}
int max(const int &a,const int &b) {
	return a>b?a:b;
}
int newnode() {
	return top?rub[top--]:++tot;
}
void pushup(int p) {
	tr[p].sz=1,tr[p].sum=tr[p].tp.val;
	tr[p].mn[0]=tr[p].mx[0]=tr[p].tp.d[0];
	tr[p].mn[1]=tr[p].mx[1]=tr[p].tp.d[1];
	if(tr[p].ls) {
		tr[p].mn[0]=min(tr[p].mn[0],tr[tr[p].ls].mn[0]);
		tr[p].mn[1]=min(tr[p].mn[1],tr[tr[p].ls].mn[1]);
		tr[p].mx[0]=max(tr[p].mx[0],tr[tr[p].ls].mx[0]);
		tr[p].mx[1]=max(tr[p].mx[1],tr[tr[p].ls].mx[1]);
		tr[p].sz+=tr[tr[p].ls].sz,tr[p].sum+=tr[tr[p].ls].sum;
	}
	if(tr[p].rs) {
		tr[p].mn[0]=min(tr[p].mn[0],tr[tr[p].rs].mn[0]);
		tr[p].mn[1]=min(tr[p].mn[1],tr[tr[p].rs].mn[1]);
		tr[p].mx[0]=max(tr[p].mx[0],tr[tr[p].rs].mx[0]);
		tr[p].mx[1]=max(tr[p].mx[1],tr[tr[p].rs].mx[1]);
		tr[p].sz+=tr[tr[p].rs].sz,tr[p].sum+=tr[tr[p].rs].sum;
	}
}
int cmp_D;
bool cmp(const point &a,const point &b) {
	return a.d[cmp_D]<b.d[cmp_D];
}
int build(int l,int r,int D) {
	if(l>r)return 0;
	int mid=(l+r)>>1,p=newnode();
	cmp_D=D;
	nth_element(a+l,a+mid+1,a+r+1,cmp);
	tr[p].tp=a[mid];
	tr[p].ls=build(l,mid-1,D^1);
	tr[p].rs=build(mid+1,r,D^1);
	pushup(p);
	return p;
}
void beat(int p,int num) {//拍平
	if(tr[p].ls)beat(tr[p].ls,num);
	a[tr[tr[p].ls].sz+num+1]=tr[p].tp,rub[++top]=p;
	if(tr[p].rs)beat(tr[p].rs,tr[tr[p].ls].sz+num+1);
}
void check(int &p,int D) {//重构
	if(tr[p].sz*0.75<tr[tr[p].ls].sz||tr[p].sz*0.75<tr[tr[p].rs].sz)
		beat(p,0),p=build(1,tr[p].sz,D);
}
void ins(int &p,point o,int D) {
	if(!p) {
		p=newnode();
		tr[p].ls=tr[p].rs=0;
		tr[p].tp=o;
		pushup(p);
		return;
	}
	if(o.d[D]<=tr[p].tp.d[D])ins(tr[p].ls,o,D^1);
	else ins(tr[p].rs,o,D^1);
	pushup(p);
	check(p,D);
}
bool in(int x1,int y1,int x2,int y2,int X1,int Y1,int X2,int Y2) {
	return x1<=X1&&X2<=x2&&y1<=Y1&&Y2<=y2;
}
bool out(int x1,int y1,int x2,int y2,int X1,int Y1,int X2,int Y2) {
	return x1>X2||x2<X1||y1>Y2||y2<Y1;
}
int query(int p,int x1,int y1,int x2,int y2) {
	if(!p)return 0;
	int res=0;
	if(in(x1,y1,x2,y2,tr[p].mn[0],tr[p].mn[1],tr[p].mx[0],tr[p].mx[1]))return tr[p].sum;
	if(out(x1,y1,x2,y2,tr[p].mn[0],tr[p].mn[1],tr[p].mx[0],tr[p].mx[1]))return 0;
	if(in(x1,y1,x2,y2,tr[p].tp.d[0],tr[p].tp.d[1],tr[p].tp.d[0],tr[p].tp.d[1]))res+=tr[p].tp.val;
	res+=query(tr[p].ls,x1,y1,x2,y2)+query(tr[p].rs,x1,y1,x2,y2);
	return res;
}
int main() {
	scanf("%d",&n);
	int opt,X1,Y1,X2,Y2,A;
	while("fyy AK IOI") {//据说在这里字符串与"1"的效果是一样的。
		scanf("%d",&opt);
		if(opt==3)return 0;
		if(opt==1) {
			scanf("%d%d%d",&X1,&X2,&A);
			X1^=lastans,X2^=lastans,A^=lastans;
			ins(root,point {X1,X2,A},0);
		}
		if(opt==2) {
			scanf("%d%d%d%d",&X1,&Y1,&X2,&Y2);
			X1^=lastans,Y1^=lastans,X2^=lastans,Y2^=lastans;
			printf("%d\n",lastans=query(root,X1,Y1,X2,Y2));
		}
	}
}

例3

P4357 [CQOI2016]K远点对

这题是 \(kd-tree\) 的经典应用:优化暴力。

发现只需要记录前k大的值,而且 \(1\le k \le 100\) (如果没有这句话可能会被卡成 \(O(n^2\log n)\) ,但是现在不可能qwq)。考虑把所有点插进 \(kd-tree\) ,然后对于每个点做一次查询,每次查询时先更新答案,再判断左右子树是否可能更新答案。这里为什么是可能呢?因为有一个“估价函数”,是通过包围盒实现的。就是当前点到包围盒的顶点距离能否更新答案,如果这都不可以,就不用遍历这棵子树了,因为一定不可能更新答案。

关于第k远,考虑搞一个小根堆,注意把size开到2k,因为每个点会算2遍 ,更新的时候与堆顶比较就行了,注意先弹出再插入,维持住2k的size。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100010;
int n,k,root;
priority_queue<ll,vector<ll> ,greater<ll> >q;
struct node {
	int ls,rs,mn[2],mx[2],d[2];
} tr[N];
int min(const int &a,const int &b) {
	return a<b?a:b;
}
int max(const int &a,const int &b) {
	return a>b?a:b;
}
int cmp_D;
bool cmp(const node &a,const node &b) {
	return a.d[cmp_D]<b.d[cmp_D];
}
ll sqr(const ll &x) {
	return x*x;
}
ll dis(const node &x,const node &y) {
	return sqr(x.d[0]-y.d[0])+sqr(x.d[1]-y.d[1]);
}
ll mxdis(const node &a,const node &b) {
	return max(sqr(a.d[0]-b.mn[0]),sqr(a.d[0]-b.mx[0]))+max(sqr(a.d[1]-b.mn[1]),sqr(a.d[1]-b.mx[1]));
}
int que;
void pushup(int p) {
	tr[p].mn[0]=tr[p].mx[0]=tr[p].d[0];
	tr[p].mn[1]=tr[p].mx[1]=tr[p].d[1];
	if(tr[p].ls) {
		tr[p].mn[0]=min(tr[p].mn[0],tr[tr[p].ls].mn[0]);
		tr[p].mn[1]=min(tr[p].mn[1],tr[tr[p].ls].mn[1]);
		tr[p].mx[0]=max(tr[p].mx[0],tr[tr[p].ls].mx[0]);
		tr[p].mx[1]=max(tr[p].mx[1],tr[tr[p].ls].mx[1]);
	}
	if(tr[p].rs) {
		tr[p].mn[0]=min(tr[p].mn[0],tr[tr[p].rs].mn[0]);
		tr[p].mn[1]=min(tr[p].mn[1],tr[tr[p].rs].mn[1]);
		tr[p].mx[0]=max(tr[p].mx[0],tr[tr[p].rs].mx[0]);
		tr[p].mx[1]=max(tr[p].mx[1],tr[tr[p].rs].mx[1]);
	}
}
int build(int l,int r,int D) {
	int mid=(l+r)>>1;cmp_D=D;
	nth_element(tr+l+1,tr+mid+1,tr+r+1,cmp);
	if(l!=mid)tr[mid].ls=build(l,mid-1,D^1);
	if(mid!=r)tr[mid].rs=build(mid+1,r,D^1);
	return pushup(mid),mid;
}
void query(int u)
{
	ll mxl=0,mxr=0,d;
	d=dis(tr[que],tr[u]);
	if(d>q.top())q.pop(),q.push(d);
	if(tr[u].ls)mxl=mxdis(tr[que],tr[tr[u].ls]);
	if(tr[u].rs)mxr=mxdis(tr[que],tr[tr[u].rs]);
	if(mxl>mxr) {
		if(mxl>q.top())query(tr[u].ls);
		if(mxr>q.top())query(tr[u].rs);
	}
	else {
		if(mxr>q.top())query(tr[u].rs);
		if(mxl>q.top())query(tr[u].ls);
	}
 } 
signed main() {
	scanf("%d%d",&n,&k);
	for(int i=1; i<=n; ++i)
		scanf("%d%d",&tr[i].d[0],&tr[i].d[1]);
	root=build(1,n,0);
	for(int i=1;i<=(k<<1);++i)q.push(0);
	for(int i=1;i<=n;++i)
	{
		que=i;
		query(root);
	}
	printf("%lld\n",q.top());
	return 0;
}

例4

P3810 【模板】三维偏序(陌上花开)

我们发现kdt本身就支持这个东西。。。只不过它是 \(O(n^{\frac{5}{3}})\) 的,过不去。

那么必须要降维。

考虑按照第三维排序,按照顺序插入。然后就变成二维数点 \(O(n\sqrt n)\) 了!

注意一个细节:第三维可能会相同,所以要先把第三维相同的点都先插入完再查询

这题卡得非常紧,kdt要卡常。。。

#include<bits/stdc++.h>
using namespace std;
#define rint register int
typedef long long LL;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1<<21],*p1=buf,*p2=buf;
inline int rd(){
   int x=0,f=1;
   char ch=getchar();
   while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
   while(isdigit(ch)) x=x*10+(ch^48),ch=getchar();
   return x*f;
}
inline int max(const int &x,const int &y) {return x>y?x:y;}
inline int min(const int &x,const int &y) {return x<y?x:y;}
const int N=200010;
int n,k,cnt,cmp_D,tot,ans[N],root;
char cltout[1<<21],*oh=cltout,*ot=cltout+(1<<21);
inline void pc(char c){
    if(oh==ot){
        fwrite(cltout,1,1<<21,stdout);
        oh=cltout;
    }
    *oh++=c;
}
inline void write(int w,char text=-1){
    if(!w)pc(48);
    else{
        int d[10];
        for(d[0]=0;w;d[++d[0]]=w%10,w/=10);
        for(;d[0];pc(d[d[0]--]^48));
    }
    if(text>=0)pc(text);
}
struct node {
	int d[3];
}a[N],b[N];
struct kdt {
	int mx[2],mn[2],d[2],siz,ch[2],tg;
}t[N];
bool cmp1(const node &x,const node &y) {
	return x.d[0]!=y.d[0]?x.d[0]<y.d[0]:x.d[1]<y.d[1];
}
bool cmp2(const node &x,const node &y) {
	return x.d[2]<y.d[2];
}
bool cmp(const node &x,const node &y) {
	return x.d[cmp_D]<y.d[cmp_D];
}
bool in(int x,int y,int X1,int Y1,int X2,int Y2) {
	return X1<=x&&x<=X2&&Y1<=y&&y<=Y2;
}
#define lc(p) t[p].ch[0]
#define rc(p) t[p].ch[1]
void pushup(int p) {
	t[p].siz=t[lc(p)].siz+t[rc(p)].siz+t[p].tg;
	t[p].mn[0]=t[p].mx[0]=t[p].d[0];
	t[p].mn[1]=t[p].mx[1]=t[p].d[1];
	if(lc(p)) 
		t[p].mn[0]=min(t[lc(p)].mn[0],t[p].mn[0]),
		t[p].mn[1]=min(t[lc(p)].mn[1],t[p].mn[1]),
		t[p].mx[0]=max(t[lc(p)].mx[0],t[p].mx[0]),
		t[p].mx[1]=max(t[lc(p)].mx[1],t[p].mx[1]);
	if(rc(p))
		t[p].mn[0]=min(t[rc(p)].mn[0],t[p].mn[0]),
		t[p].mn[1]=min(t[rc(p)].mn[1],t[p].mn[1]),
		t[p].mx[0]=max(t[rc(p)].mx[0],t[p].mx[0]),
		t[p].mx[1]=max(t[rc(p)].mx[1],t[p].mx[1]);
}
void upd(int p) {
	t[p].siz=t[lc(p)].siz+t[rc(p)].siz+t[p].tg;
}
int build(int l,int r,int D) {
	if(l>r)return 0;
	cmp_D=D;int mid=(l+r)>>1;int p=++tot;
	nth_element(b+l,b+mid+1,b+r+1,cmp);
	t[p].d[0]=t[p].mn[0]=t[p].mx[0]=b[mid].d[0];
	t[p].d[1]=t[p].mn[1]=t[p].mx[1]=b[mid].d[1];
	t[p].ch[0]=build(l,mid-1,D^1);
	t[p].ch[1]=build(mid+1,r,D^1);
	pushup(p);return p;
}
void insert(int p,int i) {
	if(a[i].d[0]==t[p].d[0]&&a[i].d[1]==t[p].d[1]) {++t[p].tg,++t[p].siz;return;}
	if(in(a[i].d[0],a[i].d[1],t[lc(p)].mn[0],t[lc(p)].mn[1],t[lc(p)].mx[0],t[lc(p)].mx[1]))insert(lc(p),i);
	if(in(a[i].d[0],a[i].d[1],t[rc(p)].mn[0],t[rc(p)].mn[1],t[rc(p)].mx[0],t[rc(p)].mx[1]))insert(rc(p),i);
	upd(p);
}
int query(int p,int i) {
	if(!p||!t[p].siz)return 0;
	if(a[i].d[0]<t[p].mn[0]||a[i].d[1]<t[p].mn[1])return 0;
	if(t[p].mx[0]<=a[i].d[0]&&t[p].mx[1]<=a[i].d[1])return t[p].siz;
	int res=query(lc(p),i)+query(rc(p),i);
	if(t[p].d[0]<=a[i].d[0]&&t[p].d[1]<=a[i].d[1])res+=t[p].tg;
	return res;
}
signed main() {
	n=rd(),k=rd();
	for(rint i=1;i<=n;++i)
		a[i].d[0]=rd(),a[i].d[1]=rd(),a[i].d[2]=rd();
	sort(a+1,a+n+1,cmp1);
	for(rint i=1;i<=n;++i)
		if(a[i].d[0]!=a[i-1].d[0]||a[i].d[1]!=a[i-1].d[1])b[++cnt]=a[i];
	root=build(1,cnt,0);
	sort(a+1,a+n+1,cmp2);
	for(rint l=1,r;l<=n;l=r+1) {
		r=l;
		while(a[r].d[2]==a[r+1].d[2]&&r<n)++r;
		for(rint i=l;i<=r;++i)insert(root,i);
		for(rint i=l;i<=r;++i)++ans[query(root,i)];
	}
	for(rint i=1;i<=n;++i)write(ans[i],'\n');
	fwrite(cltout,1,oh-cltout,stdout),oh=cltout;
	return 0;
}

例5

P5471 [NOI2019]弹跳

挺裸的,就是让你优化二维建图然后跑最短路.

\(\color{black}{\texttt{c}}\color{red}{\texttt{yn2006}}\) 用二维线段树过去的.可是我不会只好拿kdt搞了

kdt可以把一个矩形分成 \(O(\sqrt{n})\) 个矩形(节点),然后暴力连边就是跑dij \(O(m\sqrt{n}\log{M})\) 的了.

注意这里应该要拆点(试试不拆点为啥会挂就知道为啥要拆了)

因为kdt左右子树是不包含当前节点的,当不可以向这个区间连边时有可能可以向这个节点连边,所以要拆.

#define N 70005
#define M 1500005
int n,m;
int to[N],cmp_D,rt;
int dis[N<<1];
bool vis[N<<1];
struct dij{
	int u,dis;
	dij(){u=dis=0;}
	dij(int u_,int d_){u=u_,dis=d_;}
	inline bool operator < (const dij&t)const{return dis>t.dis;}
};
struct edge{int nxt,to,val;}e[30000000];
int head[N<<1],num_edge;
void addedge(int fr,int to,int val){
	++num_edge;
	e[num_edge].nxt=head[fr];
	e[num_edge].to=to;
	e[num_edge].val=val;
	head[fr]=num_edge;
}
struct kdt{
	int mx[2],mn[2],ch[2],d[2],id;
	#define ls(x) t[x].ch[0]
	#define rs(x) t[x].ch[1]
}t[N];
bool cmp(const kdt&a,const kdt&b){return a.d[cmp_D]<b.d[cmp_D];}
void pushup(int p){
	t[p].mx[0]=t[p].mn[0]=t[p].d[0];
	t[p].mx[1]=t[p].mn[1]=t[p].d[1];
	if(ls(p))
		t[p].mx[0]=max(t[p].mx[0],t[ls(p)].mx[0]),
		t[p].mx[1]=max(t[p].mx[1],t[ls(p)].mx[1]),
		t[p].mn[0]=min(t[p].mn[0],t[ls(p)].mn[0]),
		t[p].mn[1]=min(t[p].mn[1],t[ls(p)].mn[1]);
	if(rs(p))
		t[p].mx[0]=max(t[p].mx[0],t[rs(p)].mx[0]),
		t[p].mx[1]=max(t[p].mx[1],t[rs(p)].mx[1]),
		t[p].mn[0]=min(t[p].mn[0],t[rs(p)].mn[0]),
		t[p].mn[1]=min(t[p].mn[1],t[rs(p)].mn[1]);
}
int build(int l,int r,int D){
	int mid=(l+r)>>1;addedge(mid,mid+n,0);
	cmp_D=D,nth_element(t+l,t+mid,t+r+1,cmp),to[t[mid].id]=mid;
	if(l!=mid)ls(mid)=build(l,mid-1,D^1),addedge(mid,ls(mid),0);
	if(r!=mid)rs(mid)=build(mid+1,r,D^1),addedge(mid,rs(mid),0);
	return pushup(mid),mid;
}
void add(int id,int val,int l,int r,int d,int u,int p){
	if(!p)return;
	if(l<=t[p].mn[0]&&t[p].mx[0]<=r&&d<=t[p].mn[1]&&t[p].mx[1]<=u)return addedge(id,p,val);
	if(r<t[p].mn[0]||l>t[p].mx[0]||u<t[p].mn[1]||d>t[p].mx[1])return;
	if(l<=t[p].d[0]&&t[p].d[0]<=r&&d<=t[p].d[1]&&t[p].d[1]<=u)addedge(id,p+n,val);
	add(id,val,l,r,d,u,ls(p)),add(id,val,l,r,d,u,rs(p));
}
void Dij(){
	priority_queue<dij>q;
	memset(dis,0x3f,sizeof(dis));
	q.push(dij(to[1]+n,dis[to[1]+n]=0));
	while(!q.empty()){
		dij now=q.top();q.pop();
		int u=now.u;
		if(vis[u])continue;
		vis[u]=1;
		for(int i=head[u];i;i=e[i].nxt){
			int v=e[i].to;
			if(dis[v]>dis[u]+e[i].val){
				dis[v]=dis[u]+e[i].val;
				if(!vis[v])q.push(dij(v,dis[v]));
			}
		}
	}
}
signed main(){
	n=read(),m=read(),read(),read();
	for(int i=1;i<=n;++i)t[i].d[0]=read(),t[i].d[1]=read(),t[i].id=i;
	rt=build(1,n,0);
	for(int i=1;i<=m;++i){
		int p=read(),t=read(),l=read(),r=read(),d=read(),u=read();
		add(to[p]+n,t,l,r,d,u,rt);
	}
	Dij();
	for(int i=2;i<=n;++i)printf("%d\n",dis[to[i]+n]);
	return 0;
}

就在我写完之后忽然发现总边数不会开,忽然发现空间只有128MB...

我想把出题人a了!!!居然卡掉了这种做法.

那就不能暴力连边再跑最短路了

考虑边跑dij边用kdt模拟连边松弛,然后还可以剪枝了!(可以少连一些边,当这个节点已经小于松弛后的最小值直接return)

#define N 70005
#define M 1500005
int n,m;
int to[N],cmp_D,rt;
int dis[N<<1];
bool vis[N<<1];
int nxt[M],head[N],cnt,ver[M];
int z[N];
int L[M],R[M],D[M],U[M],W[M];
struct kdt{
	int mx[2],mn[2],ch[2],d[2],id;
	#define ls(x) t[x].ch[0]
	#define rs(x) t[x].ch[1]
}t[N];
struct dij{
	int u,dis;
	dij(){u=dis=0;}
	dij(int u_,int d_){u=u_,dis=d_;}
	inline bool operator < (const dij&t)const{return dis>t.dis;}
};
bool cmp(const kdt&a,const kdt&b){return a.d[cmp_D]<b.d[cmp_D];}
void pushup(int p){
	t[p].mx[0]=t[p].mn[0]=t[p].d[0];
	t[p].mx[1]=t[p].mn[1]=t[p].d[1];
	if(ls(p))
		t[p].mx[0]=max(t[p].mx[0],t[ls(p)].mx[0]),
		t[p].mx[1]=max(t[p].mx[1],t[ls(p)].mx[1]),
		t[p].mn[0]=min(t[p].mn[0],t[ls(p)].mn[0]),
		t[p].mn[1]=min(t[p].mn[1],t[ls(p)].mn[1]);
	if(rs(p))
		t[p].mx[0]=max(t[p].mx[0],t[rs(p)].mx[0]),
		t[p].mx[1]=max(t[p].mx[1],t[rs(p)].mx[1]),
		t[p].mn[0]=min(t[p].mn[0],t[rs(p)].mn[0]),
		t[p].mn[1]=min(t[p].mn[1],t[rs(p)].mn[1]);
}
int build(int l,int r,int D){
	int mid=(l+r)>>1;
	cmp_D=D,nth_element(t+l,t+mid,t+r+1,cmp),to[t[mid].id]=mid;
	if(l!=mid)ls(mid)=build(l,mid-1,D^1);
	if(r!=mid)rs(mid)=build(mid+1,r,D^1);
	return pushup(mid),mid;
}
void get(int p,int l,int r,int d,int u,int lim){
	if(!p||dis[p]<=lim)return;
	if(l<=t[p].mn[0]&&t[p].mx[0]<=r&&d<=t[p].mn[1]&&t[p].mx[1]<=u)return z[++z[0]]=p,void();
	if(r<t[p].mn[0]||l>t[p].mx[0]||u<t[p].mn[1]||d>t[p].mx[1])return;
	if(l<=t[p].d[0]&&t[p].d[0]<=r&&d<=t[p].d[1]&&t[p].d[1]<=u)z[++z[0]]=p+n;
	get(ls(p),l,r,d,u,lim),get(rs(p),l,r,d,u,lim);
}
void Dij(){
	priority_queue<dij>q;
	memset(dis,0x3f,sizeof(dis));
	q.push(dij(to[1]+n,dis[to[1]+n]=0));
	while(!q.empty()){
		dij now=q.top();q.pop();
		int u=now.u;
		if(vis[u])continue;
		vis[u]=1;
		if(u<=n){
			if(ls(u)&&dis[ls(u)]>dis[u])dis[ls(u)]=dis[u],q.push(dij(ls(u),dis[ls(u)]));
			if(rs(u)&&dis[rs(u)]>dis[u])dis[rs(u)]=dis[u],q.push(dij(rs(u),dis[rs(u)]));
			if(dis[u+n]>dis[u])dis[u+n]=dis[u],q.push(dij(u+n,dis[u+n]));
		}else{
			for(int i=head[u-n];i;i=nxt[i]){
				int t=ver[i],V=dis[u]+W[t];
				z[0]=0,get(rt,L[t],R[t],D[t],U[t],V);
				for(int j=1;j<=z[0];++j)
					if(dis[z[j]]>V)dis[z[j]]=V,q.push(dij(z[j],dis[z[j]]));
			}
		}
	}
}
signed main(){
	n=read(),m=read(),read(),read();
	for(int i=1;i<=n;++i)t[i].d[0]=read(),t[i].d[1]=read(),t[i].id=i;
	rt=build(1,n,0);
	for(int i=1;i<=m;++i){
		int x=read();
		ver[++cnt]=i,nxt[cnt]=head[to[x]],head[to[x]]=cnt;
		W[i]=read(),L[i]=read(),R[i]=read(),D[i]=read(),U[i]=read();
	}
		
	Dij();
	for(int i=2;i<=n;++i)printf("%d\n",dis[to[i]+n]);
	return 0;
}

一件事情

某天无意间翻倒这个帖子 link

国家队选手jmr说替罪羊重构复杂度是假的

【upd on 2021.2.18】翻到了 OI-wiki。skip2004说可以每隔 \(O(\sqrt{n})\) 次操作重构整棵树,并且复杂度为 \(O(n\sqrt{n}\log n)\) 。应该是对的。

posted @ 2020-11-02 12:15  zzctommy  阅读(199)  评论(0编辑  收藏  举报