Tsinsen A1486. 树(王康宁)

Description

一棵树,问至少有 \(k\) 个黑点的路径最大异或和.

Sol

点分治.

用点分治找重心控制树高就不说了,主要是对答案的统计的地方.

将所有路径按点的个数排序.

可以发现当左端点递增的时候右端点单调递减,时刻满足Trie树里的所有元素都是合法的即可,不断把右端点丢进去,用左端点统计答案.

主要跨越根的时候根的贡献计算了两次,需要删掉一次.

对于需要满足不是一颗子树,可以将Trie树上的节点打一个标记,表示这个节点及其子节点都是在某子树下的路径,子树个数大于1的时候这个标记就没用了.

我在维护标记的时候标记位置打错了...居然有95...调了好长时间QAQ...

Code

#include <bits/stdc++.h>
using namespace std;

#define debug(a) cout<<#a<<"="<<a<<" "
const int N = 1e5+50;
const int M = 31;

int n,k,kk,rt,ans=-1;
int pow2[M];
int bl[N],v[N],sz[N],t[N];
vector< int > g[N];
int usd[N];

struct pr { int x,y,z; };
bool operator < (const pr &a,const pr &b) { return a.x<b.x; }
vector< pr > S;

struct Trie {
	int cnt,rt;
	int ch[N*M][2],s[N*M],bl[N*M];
	
	int GetNode() { cnt++;ch[cnt][0]=ch[cnt][1]=s[cnt]=0;return cnt; }
	void init() {
		cnt=0,rt=GetNode();
	}
	void insert(int x,int fr) {
		int o=rt,r;
		for(int i=M-1;~i;i--) {
			if(x&pow2[i]) r=1;else r=0;
			if(!ch[o][r]) ch[o][r]=GetNode(),bl[ch[o][r]]=fr;
			else bl[ch[o][r]]=bl[ch[o][r]]==fr ? fr : 0;
			o=ch[o][r],s[o]++;
		}
	}
	int getv(int x,int fr) {
		int o=rt,r,res=0;
		if(!ch[rt][0] && !ch[rt][1]) return -1;
		for(int i=M-1;~i;i--) {
			if(x&pow2[i]) r=1;else r=0;
			if(s[ch[o][r^1]] && bl[ch[o][r^1]]!=fr) res|=pow2[i],r^=1;
			if(bl[ch[o][r]]==fr) return -1;
			o=ch[o][r];
		}return res;
	}
}py;
inline int in(int x=0,char ch=getchar()) { while(ch>'9' || ch<'0') ch=getchar();
	while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();return x; }

void GetRoot(int u,int fa,int nn) {
	t[u]=0,sz[u]=1;
	for(vector< int >::iterator i=g[u].begin();i!=g[u].end();i++) 
		if((*i)!=fa && !usd[(*i)]) GetRoot(*i,u,nn),t[u]=max(t[u],sz[*i]),sz[u]+=sz[*i];
	t[u]=max(t[u],nn-sz[u]);
	if(t[u]<t[rt]) rt=u;
}
void GetS(int u,int fa,int c,int vv,int ff) {
	S.push_back((pr){ c,vv,ff });
	if(c>=k) ans=max(ans,vv); 
	for(vector< int >::iterator i=g[u].begin();i!=g[u].end();i++)
		if((*i)!=fa && !usd[*i]) GetS(*i,u,c+bl[*i],vv^v[*i],ff);
}
void GetAns(int u,int fa,int nn) {
	usd[u]=1,py.init(),S.clear();if(bl[u]>=k) ans=max(ans,v[u]);
	for(vector< int >::iterator i=g[u].begin();i!=g[u].end();i++)
		if((*i)!=fa && !usd[(*i)]) GetS((*i),u,bl[u]+bl[(*i)],v[(*i)],*i);
	sort(S.begin(),S.end());
	
//	cout<<u<<" : "<<nn<<endl;
//	for(int i=0;i<(int)S.size();i++) cout<<S[i].x<<" "<<S[i].y<<" "<<S[i].z<<endl;
	
	int lim=S.size(),l=0,r=lim-1;
	for(;l<lim;l++) {
		while(l<r && S[l].x+S[r].x>=k+bl[u]) py.insert(S[r].y,S[r].z),r--;
		ans=max(ans,py.getv(S[l].y^v[u],S[l].z));
//		debug(l),debug(r),debug(ans)<<endl;
	}
//	debug(ans)<<endl;
//	cout<<"-------------------------"<<endl;
	
	int ss;
	for(vector< int >::iterator i=g[u].begin();i!=g[u].end();i++) 
		if((*i)!=fa && !usd[(*i)]) rt=0,ss=sz[(*i)]>sz[u] ? nn-sz[u] : sz[*i],GetRoot((*i),u,ss),GetAns(rt,rt,ss);
}
int main() {
	n=in(),k=in();
	for(int i=1;i<=n;i++) bl[i]=in();
	for(int i=1;i<=n;i++) v[i]=in();
	for(int i=1,u,v;i<n;i++) u=in(),v=in(),g[u].push_back(v),g[v].push_back(u);
	
	pow2[0]=1;for(int i=1;i<M;i++) pow2[i]=pow2[i-1]<<1;
//	for(int i=0;i<M;i++) cout<<pow2[i]<<endl;
	rt=0,t[rt]=n+1,GetRoot(1,1,n),GetAns(rt,rt,n);
	
	cout<<ans<<endl;
	return 0;
}

 

posted @ 2016-12-24 14:49  北北北北屿  阅读(177)  评论(0编辑  收藏  举报