洛谷 P5984 - [PA2019]Podatki drogowe(点分治+主席树+随机二分)

洛谷题面传送门

首先点分治,对于重心为 \(x\) 的连通块,我们用主席树维护出 \(x\) 到连通块内每个点所组成的数,具体来说,主席树上下标为 \(1\) 的位置存储将权值转成 \(n\) 进制后 \(n^1\) 位上的值,下标为 \(2,3,\cdots,n\) 位上的值同理。这样显然我们可以通过主席树上二分在 \(\log n\) 的时间内比较两个 \(n\)\(n\) 进制数的大小。然后我们将该连通块内所有权值存到一个系数 \(1\) 的集合内,然后再将所有属于以 \(x\) 为根的同一子树内的所有权值存到一个系数为 \(-1\) 的集合内,这样相当于一个数 \(X\) 的出现次数就是对于所有集合,计算集合内有多少对数和等于 \(X\),乘上集合的系数再求和。

二分答案。check 是容易的,直接在每个集合内进行 two pointers 即可,由于集合大小总和是 \(n\log n\) 的,而比较大小是 \(\log n\) 的,所以单次 check 是两只 log 的,现在比较麻烦的地方在于,值域是 \(n^n\) 的,因此直接二分需要进行 \(n\) 次二分,实在不能接受。不过需要认识到一个清醒的事实,我们二分的对象只有 \(n^2\) 个,有什么方法使得二分次数达到 \(\log n\) 呢?一个很直接的想法是二分排名然后根据排名推回对应的数,不过这不就是原问题吗(笑),显然不可行,这时有个 trick:随机二分,即在对应区间中等概率随机一个 \(mid\) 将区间分成两半,那么又该怎样等概率呢?每次二分到一个区间 \([l,r]\),我们先将所有集合中的元素打乱到一起并排序,对于每个权值 \(X\),我们找到权值在 \([l-X,r-X]\) 中的权值组成的区间,这显然是可以 two pointers 的。然后对这些区间进行带权随机,即将区间长度为随机的权随机出左边的部分,然后再在对应区间中随机出右边的部分。

时间复杂度 \(n\log^3n\)

upd:咕完了。

const int MAXN=2.5e4;
const int MAXC=5e5;
const int MAXP=1e7;
const int MOD=1e9+7;
const int INF=0x3f3f3f3f;
const u64 BS=131;
mt19937_64 rng(time(0));
int qpow(int x,int e){int ret=1;for(;e;e>>=1,x=1ll*x*x%MOD)if(e&1)ret=1ll*ret*x%MOD;return ret;}
int n,k,hd[MAXN+5],to[MAXN*2+5],val[MAXN*2+5],nxt[MAXN*2+5],ec;
void adde(int u,int v,int w){to[++ec]=v;val[ec]=w;nxt[ec]=hd[u];hd[u]=ec;}
u64 pw[MAXN+5];int pwn[MAXN+5];
namespace HJTree{
	struct node{int ch[2],val;u64 hs;}s[MAXP+5];
	int ncnt;
	int modify(int k,int l,int r,int p,int v){
		int z=++ncnt;s[z]=s[k];s[z].val+=v;s[z].hs+=v*pw[p];
		if(l==r)return z;int mid=l+r>>1;
		if(p<=mid)s[z].ch[0]=modify(s[k].ch[0],l,mid,p,v);
		else s[z].ch[1]=modify(s[k].ch[1],mid+1,r,p,v);
		return z;
	}
	int _getcmp1(int k1,int k2,int l,int r){
		if(l==r)return (s[k1].val<s[k2].val)?-1:1;
		int mid=l+r>>1;
		if(s[s[k1].ch[1]].hs==s[s[k2].ch[1]].hs)return _getcmp1(s[k1].ch[0],s[k2].ch[0],l,mid);
		else return _getcmp1(s[k1].ch[1],s[k2].ch[1],mid+1,r);
	}
	int getcmp1(int k1,int k2){
		if(s[k1].hs==s[k2].hs)return 0;
		return _getcmp1(k1,k2,1,n);
	}
	int _getcmp2(int k1,int k2,int k3,int l,int r){
		if(l==r)return (s[k1].val+s[k2].val<s[k3].val)?-1:1;
		int mid=l+r>>1;
		if(s[s[k1].ch[1]].hs+s[s[k2].ch[1]].hs==s[s[k3].ch[1]].hs)
			return _getcmp2(s[k1].ch[0],s[k2].ch[0],s[k3].ch[0],l,mid);
		else return _getcmp2(s[k1].ch[1],s[k2].ch[1],s[k3].ch[1],mid+1,r);
	}
	int getcmp2(int k1,int k2,int k3){
		if(s[k1].hs+s[k2].hs==s[k3].hs)return 0;
		return _getcmp2(k1,k2,k3,1,n);
	}
	int merge(int x,int y,int l,int r){
		if(!x||!y)return x+y;int z=++ncnt,mid=l+r>>1;
		s[z].hs=s[x].hs+s[y].hs;s[z].val=s[x].val+s[y].val;
		if(l==r)return z;
		s[z].ch[0]=merge(s[x].ch[0],s[y].ch[0],l,mid);
		s[z].ch[1]=merge(s[x].ch[1],s[y].ch[1],mid+1,r);
		return z;
	}
}using namespace HJTree;
vector<pii>vec[MAXC+5];vector<pii>all;
int vcnt,rt[MAXN+5],dis[MAXN+5];
namespace CDT{
	int siz[MAXN+5],mx[MAXN+5],cent,vis[MAXN+5];
	vector<int>pts;
	void findcent(int x,int f,int totsz){
		siz[x]=1;mx[x]=0;
		for(int e=hd[x];e;e=nxt[e]){
			int y=to[e];if(y==f||vis[y])continue;
			findcent(y,x,totsz);siz[x]+=siz[y];chkmax(mx[x],siz[y]);
		}chkmax(mx[x],totsz-siz[x]);
		if(mx[x]<mx[cent])cent=x;
	}
	void dfs_calc(int x,int f,int r){
		vec[vcnt].pb(mp(rt[x],r));all.pb(mp(rt[x],dis[x]));
		for(int e=hd[x];e;e=nxt[e]){
			int y=to[e],z=val[e];if(y==f||vis[y])continue;
			rt[y]=modify(rt[x],1,n,z,1);
			dis[y]=(dis[x]+pwn[z])%MOD;dfs_calc(y,x,r);
		}
	}
	int calcsiz(int x,int f){
		int tot=1;
		for(int e=hd[x];e;e=nxt[e]){
			int y=to[e];if(y==f||vis[y])continue;
			tot+=calcsiz(y,x);
		}return tot;
	}
	void divcent(int x){
		vis[x]=1;rt[x]=dis[x]=0;++vcnt;vec[vcnt].pb(mp(rt[x],x));
		for(int e=hd[x];e;e=nxt[e]){
			int y=to[e],z=val[e];if(vis[y])continue;
			dis[y]=pwn[z];rt[y]=modify(rt[x],1,n,z,1);dfs_calc(y,x,y);
		}
		for(int e=hd[x];e;e=nxt[e]){
			int y=to[e];if(vis[y])continue;
			cent=0;findcent(y,x,calcsiz(y,x));divcent(cent);
		}
	}
}using namespace CDT;
int calc(int pt){
	int sum=0;static int cnt[MAXN+5];
	for(int i=1;i<=vcnt;i++){
		int cur=0;
		for(int j=(int)(vec[i].size())-1;~j;--j){
			while(cur<j&&getcmp2(vec[i][cur].fi,vec[i][j].fi,pt)<=0)
				cnt[vec[i][cur].se]++,++cur;
			while(cur>j+1)--cur,cnt[vec[i][cur].se]--;
			sum+=cur-cnt[vec[i][j].se];
		}
		for(int j=0;j<cur;j++)cnt[vec[i][j].se]--;
	}
	return sum;
}
int main(){
	scanf("%d%d",&n,&k);for(int i=(pw[0]=1);i<=n;i++)pw[i]=pw[i-1]*BS;
	for(int i=1,u,v,w;i<n;i++)scanf("%d%d%d",&u,&v,&w),adde(u,v,w),adde(v,u,w);
	for(int i=(pwn[0]=1);i<=n;i++)pwn[i]=1ll*pwn[i-1]*n%MOD;
	mx[0]=INF;findcent(1,0,n);divcent(cent);
	int _inf=modify(0,1,n,n,n+1);all.pb(mp(_inf,INF));all.pb(mp(0,0));
	for(int i=1;i<=vcnt;i++)sort(vec[i].begin(),vec[i].end(),[&](pii x,pii y){return getcmp1(x.fi,y.fi)<0;});
	sort(all.begin(),all.end(),[&](pii x,pii y){return getcmp1(x.fi,y.fi)<0;});
	all.resize(unique(all.begin(),all.end())-all.begin());
	int L=0,R=all.back().fi,res=0;
	while(clock()<4.8*CLOCKS_PER_SEC){
		static pii lim[MAXC+5];
		for(int i=all.size()-1,p1=0,p2=0;~i;i--){
			while(p2<all.size()&&getcmp2(all[i].fi,all[p2].fi,R)<0)++p2;
			while(p1<all.size()&&getcmp2(all[i].fi,all[p1].fi,L)<=0)++p1;
			lim[i]=mp(p1,p2);
		}
		ll sum=0;
		for(int i=0;i<all.size();i++)sum+=max(lim[i].se-lim[i].fi,0);
		if(!sum)break;
		ll v=rng()%sum+1;int X=-1,Y=-1,pt=0;
		for(int i=0;i<all.size();i++)if(lim[i].fi<lim[i].se){
			if(v<=lim[i].se-lim[i].fi){X=i;Y=lim[i].fi+v-1;break;}
			else v-=(lim[i].se-lim[i].fi);
		}
		pt=merge(all[X].fi,all[Y].fi,1,n);
		if(calc(pt)>=k)res=(all[X].se+all[Y].se)%MOD,R=pt;
		else L=pt;
	}
	printf("%d\n",res);
	return 0;
}
posted @ 2022-08-10 17:52  tzc_wk  阅读(147)  评论(0)    收藏  举报