Solution Set - 点分治


A[POJ1741].给定一棵树,边有权,求长度不超过\(k\)的路径数目。
B[HDU4871].给定一张图,边有权,求它的最短路径树上恰含\(k\)个点的路径中最长路径的长度及数目。
C[HDU4812].给定一棵树,点有权,求字典序最小的一个点对,其路径上的所有点权之积模\(100003\)等于\(k\)
D[HDU5469].给定一棵树,点上有字母,给定一个询问串,求是否存在一条路径,其所有点的字母依次连接为询问串。
E[HDU4670].给定一棵树,点有权,点权的质因数集合给定(大小\(\le30\)),求所有点权乘积为立方数的路径数目。
F[HDU5664].给定一棵树,边有权,对所有不具有祖先关系的无序点对求距离,求距离中的第\(k\)大值。
G[HDU5977].给定一棵树,点有类型(数目\(\le10\)),求所有类型的点都出现的路径数。
H[HDU5314].给定一棵树,点有权,求点权极差不超过\(d\)的路径数目。
I[HDU5909].给定一棵树,点有权(\(\lt2^m,m\le10\)),对\(0\le i \lt 2^m\),求点权异或和为\(i\)的连通块数目。
J[HDU5102].给定一棵树,求边数最小的\(k\)条路径的边数之和。
K[洛谷P6329].给定一棵树,在线修改点权,询问距离某个点的距离不超过\(k\)的点权和。
L[SPOJ-QTREE4].给定一棵树,边有权,点初始全为白,修改点的颜色,询问白点最远距离。
M[SPOJ-QTREE5].给定一棵树,边有权,点初始全为白,修改点的颜色,询问某个点到最近的白点的距离。
N[HDU5016].给定一棵树,边有权,某些点初始有标记。新标记一个点,使得以其为最近标记点的点数最大化。
O[HDU5571].给定一棵树,边有权,点有权,求所有点对的距离与点权异或值之积的和。


A点分治模板。
B还是模板。
C依然是模板。
D要用哈希优化,其它还是模板。
E用三进制状压,其它还是模板。
F二分,注意预处理点分治的结果。
G状压,统计答案时直接暴力枚举。
H用树状数组优化,比较板。
I点分治,求DFS序转化为序列DP,具体看代码。
J和F差不多。
K点分树,用树状数组维护。
L点分树,用两个堆维护最大值和次大值,需要卡常。
M点分树,和L差不多的维护。
N先求出每个点的到已给标记点的最近距离,再点分治,需要\(d_u+d_v<dis_v\),即\(d_u<dis_v-d_v\),用二分求个数。
O按位处理,每一位在点分树上考虑,记录每棵子树0,1的个数及距离和,需要卡常。


点击查看A题代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<cstring>
using namespace std;
const int N=1e4+5,M=1e7+5,INF=1<<30;
int n,k,ans;
int rt,sum,vis[N],siz[N],mx[N];
int head[N],ver[N<<1],nxt[N<<1],val[N<<1],tot;
void adde(int u,int v,int w){
	ver[++tot]=v;
	val[tot]=w;
	nxt[tot]=head[u];
	head[u]=tot;
}
int c[M+5];
void add(int x,int v){for(x++;x<=M;x+=x&-x)c[x]+=v;}
int ask(int x){int res=0;for(x++;x;x-=x&-x)res+=c[x];return res;}
queue<int> tmp,tmp1;
void calcsize(int u,int fa){
	siz[u]=1;mx[u]=0;
	for(int i=head[u];i;i=nxt[i])
		if(ver[i]!=fa&&!vis[ver[i]]){
			calcsize(ver[i],u);
			siz[u]+=siz[ver[i]];
			mx[u]=max(mx[u],siz[ver[i]]);
		}
	mx[u]=max(mx[u],sum-siz[u]);
	if(mx[u]<mx[rt])rt=u;
}
void calcdist(int u,int fa,int dis){
	if(dis>k)return;
	ans=ans+ask(k-dis);
	tmp.push(dis);tmp1.push(dis);
	for(int i=head[u];i;i=nxt[i])
		if(ver[i]!=fa&&!vis[ver[i]])
			calcdist(ver[i],u,dis+val[i]);
}
void dfs(int u,int fa){
	vis[u]=1;add(0,1);
	for(int i=head[u];i;i=nxt[i]){
		if(ver[i]==fa||vis[ver[i]])continue;
		calcdist(ver[i],u,val[i]);
		while(!tmp.empty()){
			add(tmp.front(),1);
			tmp.pop();
		}
	}
	while(!tmp1.empty()){
		add(tmp1.front(),-1);
		tmp1.pop();
	}add(0,-1);
	for(int i=head[u];i;i=nxt[i])
		if(ver[i]!=fa&&!vis[ver[i]]){
			sum=siz[ver[i]];rt=0;
			mx[rt]=INF;
			calcsize(ver[i],u);
			calcsize(rt,-1);
			dfs(rt,u);
		}
}
void init(){
	tot=ans=0;
	memset(head,0,sizeof(head));
	memset(vis,0,sizeof(vis));	
}
int main(){
	while(scanf("%d%d",&n,&k),n!=0&&k!=0){
		init();
		for(int i=1,u,v,w;i<n;i++){
			scanf("%d%d%d",&u,&v,&w);
			adde(u,v,w);adde(v,u,w);
		}
		rt=0;sum=n;mx[rt]=INF;
		calcsize(1,-1);
		calcsize(rt,-1);
		dfs(rt,-1);
		printf("%d\n",ans);
	}
	return 0;
}
点击查看B题代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1.2e5+5;
int Test,n,m,k;
struct Graph{
	int head[N],nxt[N],ver[N],val[N],tot=0;
	void add(int u,int v,int w){
		ver[++tot]=v;val[tot]=w;
		nxt[tot]=head[u];head[u]=tot;
	}
	void init(){
		tot=1;
		memset(head,0,sizeof(head));
	}
}G,T;
struct Dijkstra{
	int dis[N],vis[N],pre[N];
	priority_queue<pair<int,int> >Q;
	void solve(){
		memset(dis,0x3f,sizeof(dis));
		memset(vis,0,sizeof(vis));
		Q.push(make_pair(1,0));dis[1]=0;
		while(!Q.empty()){
			int u=Q.top().first;Q.pop();
			if(vis[u])continue;
			for(int i=G.head[u],v;i;i=G.nxt[i])
				if(dis[v=G.ver[i]]>dis[u]+G.val[i]||
				   dis[v]==dis[u]+G.val[i]&&u<G.ver[pre[v]]){
					dis[v]=dis[u]+G.val[i];pre[v]=i^1;
					Q.push(make_pair(v,-dis[v]));
				}
		}
	}
}Dijk;
struct Point_Divide{
	int sum,rt,sz[N],mx[N],vis[N],lim;
	ll ans,cnt,maxd[N],cntd[N];
	void csize(int u,int fa){
		sz[u]=1;mx[u]=0;
		for(int i=T.head[u],v;i;i=T.nxt[i])
			if((v=T.ver[i])!=fa&&!vis[v])
				csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
		mx[u]=max(mx[u],sum-sz[u]);
		if(mx[u]<mx[rt])rt=u;
	}
	void calc(int u,int fa,ll dis,int dep){
		if(dep<k&&cntd[k-dep-1]){
			if(dis+maxd[k-dep-1]>ans)ans=dis+maxd[k-dep-1],cnt=cntd[k-dep-1];
			else if(dis+maxd[k-dep-1]==ans)cnt+=cntd[k-dep-1];
		}
		for(int i=T.head[u],v;i;i=T.nxt[i])
			if((v=T.ver[i])!=fa&&!vis[v])calc(v,u,dis+T.val[i],dep+1);
	}
	void upd(int u,int fa,ll dis,int dep){
		lim=max(lim,dep);
		if(dis>maxd[dep])maxd[dep]=dis,cntd[dep]=1;
		else if(dis==maxd[dep])++cntd[dep];
		for(int i=T.head[u],v;i;i=T.nxt[i])
			if((v=T.ver[i])!=fa&&!vis[v])upd(v,u,dis+T.val[i],dep+1);
	}
	void solve(int u){
		maxd[0]=0;cntd[0]=1;lim=0;
		for(int i=T.head[u],v;i;i=T.nxt[i])
			if(!vis[v=T.ver[i]]){calc(v,u,T.val[i],1);upd(v,u,T.val[i],1);}
		for(int i=1;i<=lim;i++)maxd[i]=cntd[i]=0;
	}
	void dfs(int u){
		vis[u]=1;solve(u);
		for(int i=T.head[u],v;i;i=T.nxt[i])
			if(!vis[v=T.ver[i]]){sum=sz[v];rt=0;csize(v,-1);csize(rt,-1);dfs(rt);}
	}
	void work(){
		memset(maxd,0,sizeof(maxd));
		memset(cntd,0,sizeof(cntd));
		memset(vis,0,sizeof(vis));
		ans=cnt=0;mx[rt=0]=sum=n;
		csize(1,-1);csize(rt,-1);dfs(rt);
	}
}PD;
int main(){
	scanf("%d",&Test);
	while(Test--){
		G.init();T.init();
		scanf("%d%d%d",&n,&m,&k);
		for(int i=1,u,v,w;i<=m;i++){
			scanf("%d%d%d",&u,&v,&w);
			G.add(u,v,w);G.add(v,u,w);
		}
		Dijk.solve();
		for(int i=2;i<=n;i++){
			T.add(i,G.ver[Dijk.pre[i]],G.val[Dijk.pre[i]]);
			T.add(G.ver[Dijk.pre[i]],i,G.val[Dijk.pre[i]]);
		}
		PD.work();
		printf("%lld %lld\n",PD.ans,PD.cnt);
	}
	return 0;
}
点击查看C题代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5,mod=1e6+3;
int n,k,sum,rt,siz[N],mx[N],vis[N];
int now,a[N],ansu,ansv,f[mod+5],inv[mod+5],z;
int head[N],nxt[N<<1],ver[N<<1],tot;
int tmp[N][2],num;
void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
void calcsize(int u,int fa){
	siz[u]=1;mx[u]=0;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v]){
			calcsize(v,u);siz[u]+=siz[v];
			mx[u]=max(mx[u],siz[v]);
		}
	mx[u]=max(mx[u],sum-siz[u]);
	if(mx[u]<mx[rt])rt=u;
}
void chk(int u,int v){
	if(u==v)return;
	if(u>v)swap(u,v);
	if(ansu==-1)ansu=u,ansv=v;
	else if(u<ansu||u==ansu&&v<ansv)ansu=u,ansv=v;
}
void calc(int u,int fa,int val){
	++num;tmp[num][0]=val;tmp[num][1]=u;
	if(1ll*val*a[now]%mod==k)chk(u,now);
	if(f[z=1ll*k*inv[val]%mod*inv[a[now]]%mod]!=-1)chk(u,f[z]);
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])
			calc(v,u,1ll*val*a[v]%mod);
}
void dfs(int u,int fa){
	now=u;vis[u]=1;num=0;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v]){
			int num0=num+1;calc(v,u,a[v]);
			for(int j=num0;j<=num;j++){
				if(f[tmp[j][0]]==-1)f[tmp[j][0]]=tmp[j][1];
				else f[tmp[j][0]]=min(f[tmp[j][0]],tmp[j][1]);
			}
		}
	for(int j=1;j<=num;j++)f[tmp[j][0]]=-1;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v]){
			sum=siz[v];rt=0;
			calcsize(v,u);calcsize(rt,-1);
			dfs(rt,u);
		}
}
void init(){
	ansu=ansv=-1;num=tot=0;
	for(int i=1;i<=n;i++)head[i]=vis[i]=0;
}
int main(){
	memset(f,-1,sizeof(f));inv[1]=1;
	for(int i=2;i<mod;i++)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
	while(scanf("%d%d",&n,&k)!=EOF){
		init();
		for(int i=1;i<=n;i++)scanf("%d",a+i);
		for(int i=1,u,v;i<n;++i){
			scanf("%d%d",&u,&v);
			add(u,v);add(v,u);
		}
		sum=n;rt=0;mx[rt]=1<<29;
		calcsize(1,-1);calcsize(rt,-1);dfs(rt,-1);
		if(ansu!=-1)printf("%d %d\n",ansu,ansv);
		else printf("No solution\n");
	}
	return 0;
}
点击查看D题代码
#pragma GCC optimize(2)
#pragma GCC optimize(3."Ofast","inline")
#include<bits/stdc++.h>
using namespace std;
typedef unsigned long long ull;
const int N=10005;
int T,n,len,ans,sum,rt,sz[N],mx[N],vis[N],y[N],a[2*N],f[2*N],cnt,p0,lenx;
char s[N],t[N];
int head[N],nxt[N<<1],ver[N<<1],tot;
void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
ull pre[N],suf[N],base=131,p[N],x[2*N];
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
inline void csize(int u,int fa){
	sz[u]=1;mx[u]=0;
	for(register int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v]){
			csize(v,u);sz[u]+=sz[v];
			mx[u]=max(mx[u],sz[v]);
		}
	mx[u]=max(mx[u],sum-sz[u]);
	if(mx[u]<mx[rt])rt=u;
}
inline void calc(int u,int fa,ull val,int dep){
	if(dep>len)return;
	p0=lower_bound(x+1,x+lenx+1,val)-x;
	if(x[p0]==val){y[++cnt]=p0;if(f[a[p0]])ans=1;}
	for(register int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])
			calc(v,u,val+p[dep+1]*s[v],dep+1);
}
inline void dfs(int u,int fa){
	vis[u]=1;cnt=0;
	p0=lower_bound(x+1,x+lenx+1,1llu*s[u])-x;
	if(x[p0]==1llu*s[u])y[++cnt]=p0,f[p0]=1;
	for(register int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v]){
			int cnt0=cnt+1;calc(v,u,s[u]+base*s[v],1);
			for(int j=cnt0;j<=cnt;++j)f[y[j]]=1;
		}
	for(register int j=1;j<=cnt;++j)f[y[j]]=0;
	for(register int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v]){
			sum=sz[v];rt=0;
			csize(v,u);csize(rt,-1);dfs(rt,u);
		}
}
int main(){
	T=read();
	for(register int cas=1;cas<=T;++cas){
		ans=tot=0;
		for(int i=1;i<=n;i++)head[i]=vis[i]=0;
		n=read();
		for(register int i=1,u,v;i<n;++i){
			u=read();v=read();
			add(u,v);add(v,u);
		}
		scanf("%s%s",s+1,t+1);
		len=strlen(t+1);
		if(len==1){
			for(register int i=1;i<=n;++i)
				if(s[i]==t[1])ans=1;
		}
		else{
			p[0]=1;pre[0]=suf[len+1]=0;
			for(register int i=1;i<=len;++i){
				x[i]=pre[i]=pre[i-1]*base+t[i];
				p[i]=p[i-1]*base;
			}
			for(register int i=len;i>=1;--i)
				x[i+len]=suf[i]=suf[i+1]*base+t[i];
			sort(x+1,x+2*len+1);
			lenx=unique(x+1,x+2*len+1)-x-1;
			for(register int i=1;i<=len;++i){
				int p1=lower_bound(x+1,x+lenx+1,pre[i])-x,
					p2=lower_bound(x+1,x+lenx+1,suf[i])-x;
				a[p1]=p2;a[p2]=p1;
			}
			sum=n;rt=0;mx[rt]=1<<30;
			csize(1,-1);csize(rt,-1);dfs(rt,-1);
		}
		printf("Case #%d: ",cas);
		puts(ans?"Find":"Impossible");
	}
	return 0;
}
点击查看E题代码
#pragma GCC optimize(2)
#pragma GCC optimize(3."Ofast","inline")
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<int> vi;
const int N=5e4+5;
int n,k,p[35],sum,rt,sz[N],mx[N],vis[N];
ll ans,pwr3[35],x;
int head[N],nxt[N<<1],ver[N<<1],tot;
void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
struct P{ll v;}a[N];
P operator +(const P&a,const P&b){
	P c;c.v=0;
	for(int i=0;i<k;i++)c.v+=(a.v/pwr3[i]%3+b.v/pwr3[i]%3)%3*pwr3[i];
	return c;
}
P opp(P a){
	P b;b.v=0;
	for(int i=0;i<k;i++)b.v+=(3-a.v/pwr3[i]%3)%3*pwr3[i];
	return b;
}
bool operator <(const P&a,const P&b){return a.v<b.v;}
map<P,ll> M;
void csize(int u,int fa){
	sz[u]=1;mx[u]=0;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v]){
			csize(v,u);sz[u]+=sz[v];
			mx[u]=max(mx[u],sz[v]);
		}
	mx[u]=max(mx[u],sum-sz[u]);
	if(mx[u]<mx[rt])rt=u;
}
void calc(int u,int fa,P val){
	ans+=M[opp(val)];
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])calc(v,u,val+a[v]);
}
void upd(int u,int fa,P val){
	++M[val];
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])upd(v,u,val+a[v]);
}
void dfs(int u,int fa){
	if(a[u].v==0)++ans;
	vis[u]=1;M.clear();M[P{0}]=1;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v]){calc(v,u,a[u]+a[v]);upd(v,u,a[v]);}
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v]){
			sum=sz[v];rt=0;
			csize(v,u);csize(rt,-1);dfs(rt,u);
		}
}
int main(){
	pwr3[0]=1;
	for(int i=1;i<=30;i++)pwr3[i]=3*pwr3[i-1];
	while(scanf("%d",&n)!=EOF){
		ans=tot=0;
		for(int i=1;i<=n;i++)head[i]=vis[i]=a[i].v=0;
		scanf("%d",&k);
		for(int i=1;i<=k;i++)scanf("%d",p+i);
		for(int i=1;i<=n;i++){
			scanf("%lld",&x);
			for(int j=1,c=0;j<=k;j++){
				for(c=0;x%p[j]==0;c++)x/=p[j];
				a[i].v+=c%3*pwr3[j-1];
			}
		}
		for(int i=1,u,v;i<n;i++){
			scanf("%d%d",&u,&v);
			add(u,v);add(v,u);
		}
		sum=n;rt=0;mx[rt]=1<<30;
		csize(1,-1);csize(rt,-1);dfs(rt,-1);
		printf("%lld\n",ans);
	}
	return 0;
}
点击查看F题代码
#include<bits/stdc++.h>
using namespace std;
const int N=5e4+5;
int T,n,m,k,d,ans,cnt[N],sum,rt,sz[N],mx[N],vis[N],now0,now;
int head[N],nxt[N<<1],ver[N<<1],val[N<<1],tot;
void add(int u,int v,int w){
	ver[++tot]=v;val[tot]=w;
	nxt[tot]=head[u];head[u]=tot;
}
void csize(int u,int fa){
	sz[u]=1;mx[u]=0;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])
			csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
	mx[u]=max(mx[u],sum-sz[u]);
	if(mx[u]<mx[rt])rt=u;
}
struct node{int id,dis;};
vector<node> a[N];
bool operator <(const node&a,const node&b){return a.dis<b.dis;}
void calc(int u,int fa,int dist){
	a[now0].push_back(node{now,dist});
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])calc(v,u,dist+val[i]);
}
void dfs(int u){
	vis[now0=u]=1;a[u].push_back(node{0,0});
	for(int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]])now=v,calc(v,u,val[i]);
	sort(a[u].begin(),a[u].end());
	for(int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]]){
			sum=sz[v];rt=0;
			csize(v,-1);csize(rt,-1);
			dfs(rt);
		}
}
void solve1(int u,int fa){
	int t=a[u].size();
	for(int l=0;l<t;l++)cnt[a[u][l].id]=0;
	for(int l=0,r=t-1;l<t;l++){
		while(r>=0&&a[u][l].dis+a[u][r].dis>=d)++cnt[a[u][r].id],--r;
		ans+=t-1-r-cnt[a[u][l].id];
	}
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa)solve1(v,u);
}
vector<int> D;
void solve2(int u,int fa,int dis){
	D.push_back(dis);
	ans-=upper_bound(D.begin(),D.end(),dis-d)-D.begin();
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa)solve2(v,u,dis+val[i]);
	D.pop_back();
}
int main(){
	mx[0]=1<<30;
	scanf("%d",&T);
	while(T--){
		tot=0;
		for(int i=1;i<=n;i++)head[i]=vis[i]=0,a[i].clear();
		scanf("%d%d%d",&n,&m,&k);
		for(int i=1,u,v,w;i<n;i++){
			scanf("%d%d%d",&u,&v,&w);
			add(u,v,w);add(v,u,w);
		}
		sum=n;rt=0;
		csize(1,-1);csize(rt,-1);dfs(rt);
		int L=1,R=5e8,res=-1;
		while(L<=R){
			d=L+R>>1;
			ans=0;solve1(m,-1);ans/=2;solve2(m,-1,0);
			if(ans>=k)res=d,L=d+1;
			else R=d-1;
		}
		if(res==-1)printf("NO\n");
		else printf("%d\n",res);
	}
	return 0;
}
点击查看G题代码
#include<bits/stdc++.h>
using namespace std;
const int N=5e4+5;
int n,k,col[N],f[N],x[N],cnt,sum,rt,sz[N],mx[N],vis[N];long long ans;
int head[N],nxt[N<<1],ver[N<<1],tot;
void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
void csize(int u,int fa){
	sz[u]=1;mx[u]=0;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])
			csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
	mx[u]=max(mx[u],sum-sz[u]);
	if(mx[u]<mx[rt])rt=u;
}
void calc(int u,int fa,int st){
	int t=((1<<k)-1)^st;ans+=f[t];
	for(int i=st;i;i=(i-1)&st)ans+=f[i^t];
	x[++cnt]=st;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])calc(v,u,st|(1<<col[v]));
}
void dfs(int u){
	vis[u]=1;f[x[cnt=1]=(1<<col[u])]=1;
	for(int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]]){
			int cnt0=cnt+1;calc(v,u,(1<<col[u])|(1<<col[v]));
			for(int j=cnt0;j<=cnt;j++)f[x[j]]++;
		}
	for(int i=0;i<(1<<k);i++)f[i]=0;
	for(int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]]){
			sum=sz[v];rt=0;
			csize(v,u);csize(rt,u);dfs(rt);
		}
}
int main(){
	while(scanf("%d%d",&n,&k)!=EOF){
		ans=tot=0;
		for(int i=1;i<=n;i++)head[i]=vis[i]=0;
		for(int i=1;i<=n;i++)scanf("%d",col+i),--col[i];
		for(int i=1,u,v;i<n;i++){scanf("%d%d",&u,&v);add(u,v);add(v,u);}
		if(k==1){printf("%lld\n",1ll*n*n);continue;}
		sum=n;mx[rt=0]=1<<30;
		csize(1,-1);csize(rt,-1);dfs(rt);
		printf("%lld\n",2*ans);
	}
	return 0;
}
点击查看H题代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5;
int T,n,m,k,d,p[N],q[N],sum,rt,sz[N],mx[N],vis[N],now;long long ans;
int head[N],nxt[N<<1],ver[N<<1],tot;
void adde(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
void csize(int u,int fa){
	sz[u]=1;mx[u]=0;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])
			csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
	mx[u]=max(mx[u],sum-sz[u]);
	if(mx[u]<mx[rt])rt=u;
}
struct node{int min,max;}a[N];
bool operator <(const node&a,const node&b){
	return a.max==b.max?a.min<b.min:a.max<b.max;
}
int c[N];
void modify(int x,int v){for(;x<=m;x+=x&-x)c[x]+=v;}
int query(int x){int res=0;for(;x;x-=x&-x)res+=c[x];return res;}
void calc(int u,int fa,int mn,int mx){
	a[++k]={mn,mx};
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])calc(v,u,min(mn,p[v]),max(mx,p[v]));
}
void dfs(int u){
	vis[u]=1;a[k=1]={p[u],p[u]};
	for(int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]]){
			int k0=k+1;
			calc(v,u,min(p[u],p[v]),max(p[u],p[v]));
			sort(a+k0,a+k+1);
			for(int j=k0;j<=k;j++){
				int tmp=lower_bound(q+1,q+m+1,q[a[j].max]-d)-q-1;
				if(a[j].min>tmp)ans-=j-k0-query(tmp);modify(a[j].min,1);
			}
			for(int j=k0;j<=k;j++)modify(a[j].min,-1);
		}
	sort(a+1,a+k+1);
	for(int j=1;j<=k;j++){
		int tmp=lower_bound(q+1,q+m+1,q[a[j].max]-d)-q-1;
		if(a[j].min>tmp)ans+=j-1-query(tmp);modify(a[j].min,1);
	}
	for(int j=1;j<=k;j++)modify(a[j].min,-1);
	for(int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]]){
			sum=sz[v];rt=0;
			csize(v,-1);csize(rt,-1);
			dfs(rt);
		}
}
int main(){
	mx[0]=1<<30;
	scanf("%d",&T);
	while(T--){
		ans=tot=0;
		for(int i=1;i<=n;i++)head[i]=vis[i]=0;
		scanf("%d%d",&n,&d);
		for(int i=1;i<=n;i++)scanf("%d",p+i),q[i]=p[i];
		sort(q+1,q+n+1);m=unique(q+1,q+n+1)-q-1;
		for(int i=1;i<=n;i++)p[i]=lower_bound(q+1,q+m+1,p[i])-q; 
		for(int i=1,u,v;i<n;i++){
			scanf("%d%d",&u,&v);
			adde(u,v);adde(v,u);
		}
		sum=n;rt=0;
		csize(1,-1);csize(rt,-1);dfs(rt);
		printf("%lld\n",ans*2);
	}
	return 0;
}
点击查看I题代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1145,mod=1e9+7;
int Test,n,m,a[N];
struct Graph{
	int head[N],nxt[N<<1],ver[N<<1],tot=0;
	void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
	void init(){tot=1;memset(head,0,sizeof(head));}
}T;
struct Point_Divide{
	int sum,rt,sz[N],mx[N],vis[N];
	int R[N],dfn,id[N],f[N][N],ans[N];
	void csize(int u,int fa){
		sz[u]=1;mx[u]=0;
		for(int i=T.head[u],v;i;i=T.nxt[i])
			if((v=T.ver[i])!=fa&&!vis[v])
				csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
		mx[u]=max(mx[u],sum-sz[u]);
		if(mx[u]<mx[rt])rt=u;
	}
	void dfs2(int u,int fa){
		id[++dfn]=u;
		for(int i=T.head[u],v;i;i=T.nxt[i])
			if((v=T.ver[i])!=fa&&!vis[v])dfs2(v,u);
		R[u]=dfn;
	}
	void solve(int u){
		dfn=0;dfs2(u,-1);
		for(int i=0;i<=dfn+1;i++)for(int j=0;j<m;j++)f[i][j]=0;
		f[1][a[u]]=1;
		for(int i=2;i<=dfn;i++)
			for(int j=0;j<m;j++){
				f[i][j^a[id[i]]]=(f[i][j^a[id[i]]]+f[i-1][j])%mod;
				f[R[id[i]]][j]=(f[R[id[i]]][j]+f[i-1][j])%mod;
			}
		for(int j=0;j<m;j++)ans[j]=(ans[j]+f[dfn][j])%mod;
	}
	void dfs(int u){
		vis[u]=1;solve(u);
		for(int i=T.head[u],v;i;i=T.nxt[i])
			if(!vis[v=T.ver[i]]){sum=sz[v];rt=0;csize(v,-1);csize(rt,-1);dfs(rt);}
	}
	void work(){
		memset(ans,0,sizeof(ans));
		memset(vis,0,sizeof(vis));
		mx[rt=0]=sum=n;
		csize(1,-1);csize(rt,-1);dfs(rt);
	}
}PD;
int main(){
	scanf("%d",&Test);
	while(Test--){
		T.init();
		scanf("%d%d",&n,&m);
		for(int i=1;i<=n;i++)scanf("%d",a+i);
		for(int i=1,u,v;i<n;i++){
			scanf("%d%d",&u,&v);
			T.add(u,v);T.add(v,u);
		}
		PD.work();
		printf("%d",PD.ans[0]);
		for(int j=1;j<m;j++)printf(" %d",PD.ans[j]);
		printf("\n");
	}
	return 0;
}
点击查看J题代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+5;
int T,n,k,cnt[N],cntl[N],sum,rt,sz[N],mx[N],vis[N],now0,now;ll length,ans;
int head[N],nxt[N<<1],ver[N<<1],tot;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
inline void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
inline void csize(int u,int fa){
	sz[u]=1;mx[u]=0;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])
			csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
	mx[u]=max(mx[u],sum-sz[u]);
	if(mx[u]<mx[rt])rt=u;
}
struct node{int id,dis;};
vector<node> a[N];
bool operator <(const node&a,const node&b){return a.dis<b.dis;}
inline void calc(int u,int fa,int dist){
	a[now0].push_back(node{now,dist});
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])calc(v,u,dist+1);
}
inline void dfs(int u){
	vis[now0=u]=1;a[u].push_back(node{0,0});
	for(int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]])now=v,calc(v,u,1);
	sort(a[u].begin(),a[u].end());
	for(int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]]){sum=sz[v];rt=0;csize(v,-1);csize(rt,-1);dfs(rt);}
}
inline void solve(int d){
	for(int u=1;u<=n;u++){
		int t=a[u].size();long long len=0;
		for(int l=0;l<t;++l)cnt[a[u][l].id]=cntl[a[u][l].id]=0;
		for(int l=0;l<t;++l){
			++cnt[a[u][l].id];
			cntl[a[u][l].id]+=a[u][l].dis;
			len+=a[u][l].dis;
		}
		for(int l=0,r=t-1;l<t;++l){
			while(r>=0&&a[u][l].dis+a[u][r].dis>d){
				--cnt[a[u][r].id];
				cntl[a[u][r].id]-=a[u][r].dis;
				len-=a[u][r].dis;
				--r;
			}
			int tmp=r+1-cnt[a[u][l].id];
			ans+=tmp;length+=len-cntl[a[u][l].id]+1ll*tmp*a[u][l].dis;
		}
	}
}
int main(){
	mx[0]=1<<30;
	T=read();
	while(T--){
		tot=0;
		for(int i=1;i<=n;++i)head[i]=vis[i]=0,a[i].clear();
		n=read();k=read();
		for(int i=1,u,v;i<n;++i){u=read();v=read();add(u,v);add(v,u);}
		sum=n;rt=0;csize(1,-1);csize(rt,-1);dfs(rt);
		int L=1,R=n,res=-1;
		while(L<=R){
			int mid=L+R>>1;ans=0;solve(mid);ans/=2;
			if(ans>=k)res=mid,R=mid-1;else L=mid+1;
		}
		ans=length=0;solve(res);
		printf("%lld\n",length/2-(ans/2-k)*res);
	}
	return 0;
}
点击查看K题代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5;
int n,m,a[N],head[N],nxt[N<<1],ver[N<<1],etot;
void add(int u,int v){ver[++etot]=v;nxt[etot]=head[u];head[u]=etot;}
int sum,rt,sz[N],mx[N],vis[N],f[N],dis[N],fa[N][20];
void dfs0(int u){
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa[u][0])
			dis[v]=dis[u]+1,fa[v][0]=u,dfs0(v);
}
void pre(){
	for(int j=1;(1<<j)<=n;j++)
		for(int i=1;i<=n;i++)if(fa[i][j-1]!=-1)
			fa[i][j]=fa[fa[i][j-1]][j-1];
}
int dist(int u,int v){
	int g,x=u,y=v;
	if(dis[u]<dis[v])swap(u,v);
	int d=dis[u]-dis[v];
	for(int i=19;i>=0;i--)
		if((d>>i)&1)u=fa[u][i];
	if(u==v)g=u;
	else{
		for(int i=19;i>=0;i--)
			if(fa[u][i]!=fa[v][i])u=fa[u][i],v=fa[v][i];
		g=fa[u][0];
	}
	return dis[x]+dis[y]-2*dis[g];
}
vector<int> c[N][2];
void upd(int u,int op,int x,int v){for(++x;x<=sz[u];x+=x&-x)c[u][op][x]+=v;}
int ask(int u,int op,int x){int res=0;for(x=min(x+1,sz[u]);x;x-=x&-x)res+=c[u][op][x];return res;}
void modify(int x,int v){
	for(int u=x;u;u=f[u])upd(u,0,dist(u,x),v);
	for(int u=x;f[u];u=f[u])upd(u,1,dist(f[u],x),v);
}
void csize(int u,int fa){
	sz[u]=1;mx[u]=0;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])
			csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
	mx[u]=max(mx[u],sum-sz[u]);
	if(mx[u]<mx[rt])rt=u;
}
void dfs(int u){
	vis[u]=1;sz[u]=sum+1;
	c[u][0].resize(sz[u]+1);c[u][1].resize(sz[u]+1);
	for(int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]]){
			sum=sz[v];rt=0;
			csize(v,-1);csize(rt,-1);
			f[rt]=u;dfs(rt);
		}
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)scanf("%d",a+i);
	for(int i=1,u,v;i<n;i++){
		scanf("%d%d",&u,&v);
		add(u,v);add(v,u);
	}
	memset(fa,-1,sizeof(fa));dfs0(1);pre();
	sum=mx[rt=0]=n;csize(1,-1);csize(rt,-1);dfs(rt);
	for(int i=1;i<=n;i++)modify(i,a[i]);
	for(int i=1,op,x,y,ans=0;i<=m;i++){
		scanf("%d%d%d",&op,&x,&y);
		x^=ans;y^=ans;
		if(op==0){
			ans=ask(x,0,y);
			for(int u=x;f[u];u=f[u]){
				int dis=dist(x,f[u]);
				if(y>=dis)ans+=ask(f[u],0,y-dis)-ask(u,1,y-dis);
			}
			printf("%d\n",ans);
		}
		else modify(x,y-a[x]),a[x]=y;
	}
	return 0;
}
点击查看L题代码
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include<bits/stdc++.h>
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
const int N=1e5+5;
int n,m,cnt,col[N],head[N],nxt[N<<1],ver[N<<1],val[N<<1],etot;char ch;int x;
inline void add(int u,int v,int w){
	ver[++etot]=v;val[etot]=w;
	nxt[etot]=head[u];head[u]=etot;
}
int sum,rt,sz[N],mx[N],vis[N],fa[N],dep[N],dis[N],anc[N][17];
inline void dfs0(int u){
	for(register int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=anc[u][0]){
			dep[v]=dep[u]+1;dis[v]=dis[u]+val[i];
			anc[v][0]=u;dfs0(v);
		}
}
inline void pre(){
	for(register int j=1;(1<<j)<=n;++j)
		for(register int i=1;i<=n;++i)if(anc[i][j-1])
			anc[i][j]=anc[anc[i][j-1]][j-1];
}
inline int dist(int u,int v){
	int g,x=u,y=v;
	if(dep[u]<dep[v])swap(u,v);
	int d=dep[u]-dep[v];
	for(register int i=16;i>=0;--i)if((d>>i)&1)u=anc[u][i];
	if(u==v)g=u;
	else{
		for(register int i=16;i>=0;--i)
			if(anc[u][i]!=anc[v][i])
				u=anc[u][i],v=anc[v][i];
		g=anc[u][0];
	}
	return dis[x]+dis[y]-2*dis[g];
}
inline void csize(int u,int fa){
	sz[u]=1;mx[u]=0;
	for(register int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])
			csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
	mx[u]=max(mx[u],sum-sz[u]);
	if(mx[u]<mx[rt])rt=u;
}
inline void dfs(int u){
	vis[u]=1;
	for(register int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]]){
			sum=sz[v];rt=0;
			csize(v,-1);csize(rt,-1);
			fa[rt]=u;dfs(rt);
		}
}
struct Heap{
	priority_queue<int> Q1,Q2;
	void push(int x){Q1.push(x);}
	void remove(int x){if(Q1.size()&&Q1.top()==x)Q1.pop();else Q2.push(x);}
	void chk(){while(Q1.size()&&Q2.size()&&Q1.top()==Q2.top())Q1.pop(),Q2.pop();}
	int top(){chk();return Q1.top();}
	int size(){return Q1.size()-Q2.size();}
	int calc(){int x=top();Q1.pop();int y=top();push(x);return x+y;}
}S1[N],S2[N],S;
int main(){
	cnt=n=read();
	for(register int i=1,u,v,w;i<n;++i){
		u=read();v=read();w=read();
		add(u,v,w);add(v,u,w);
	}
	dfs0(1);pre();
	sum=mx[rt=0]=n;csize(1,-1);csize(rt,-1);dfs(rt);
	for(register int i=1;i<=n;++i)S2[i].push(0);S.push(0);
	for(register int i=1;i<=n;++i)
		for(register int j=i;fa[j];j=fa[j])S1[j].push(dist(i,fa[j]));
	for(register int i=1;i<=n;++i)if(fa[i])S2[fa[i]].push(S1[i].top());
	for(register int i=1;i<=n;++i)if(S2[i].size()>=2)S.push(S2[i].calc());
	m=read()+1;
	while(--m){
		while(ch=getchar(),ch!='A'&&ch!='C');
		if(ch=='A'){
			if(cnt>=2)printf("%d\n",S.top());
			else printf(cnt?"0\n":"They have disappeared.\n");
		}
		else{
			x=read();col[x]^=1;if(!col[x])++cnt;else --cnt;
			if(S2[x].size()>=2)S.remove(S2[x].calc());
			if(!col[x])S2[x].push(0);else S2[x].remove(0);
			if(S2[x].size()>=2)S.push(S2[x].calc());
			for(register int y=x;fa[y];y=fa[y]){
				int t=fa[y];
				if(S2[t].size()>=2)S.remove(S2[t].calc());
				if(S1[y].size())S2[t].remove(S1[y].top());
				if(!col[x])S1[y].push(dist(x,t));
				else S1[y].remove(dist(x,t));
				if(S1[y].size())S2[t].push(S1[y].top());
				if(S2[t].size()>=2)S.push(S2[t].calc());
			}
		}
	}
	return 0;
}
点击查看M题代码
#include<bits/stdc++.h>
using namespace std;
int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
const int N=1e5+5;
int n,m,cnt,col[N],head[N],nxt[N<<1],ver[N<<1],etot;int op,x;
void add(int u,int v){ver[++etot]=v;nxt[etot]=head[u];head[u]=etot;}
int sum,rt,sz[N],mx[N],vis[N],fa[N],dep[N],dis[N],anc[N][17];
void dfs0(int u){
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=anc[u][0]){dep[v]=dep[u]+1;anc[v][0]=u;dfs0(v);}
}
void pre(){
	for(int j=1;(1<<j)<=n;++j)
		for(int i=1;i<=n;++i)if(anc[i][j-1])
			anc[i][j]=anc[anc[i][j-1]][j-1];
}
int dist(int u,int v){
	int g,x=u,y=v;
	if(dep[u]<dep[v])swap(u,v);
	for(int i=16,d=dep[u]-dep[v];i>=0;--i)if((d>>i)&1)u=anc[u][i];
	if(u==v)g=u;
	else{
		for(int i=16;i>=0;--i)
			if(anc[u][i]!=anc[v][i])u=anc[u][i],v=anc[v][i];
		g=anc[u][0];
	}
	return dep[x]+dep[y]-2*dep[g];
}
void csize(int u,int fa){
	sz[u]=1;mx[u]=0;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])
			csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
	mx[u]=max(mx[u],sum-sz[u]);
	if(mx[u]<mx[rt])rt=u;
}
void dfs(int u){
	vis[u]=1;
	for(int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]]){
			sum=sz[v];rt=0;
			csize(v,-1);csize(rt,-1);
			fa[rt]=u;dfs(rt);
		}
}
struct Heap{
	priority_queue<int,vector<int>,greater<int> > Q1,Q2;
	void push(int x){Q1.push(x);}
	void remove(int x){if(Q1.size()&&Q1.top()==x)Q1.pop();else Q2.push(x);}
	void chk(){while(Q1.size()&&Q2.size()&&Q1.top()==Q2.top())Q1.pop(),Q2.pop();}
	int top(){chk();return Q1.top();}
	int size(){return Q1.size()-Q2.size();}
}S1[N],S2[N];
int main(){
	cnt=n=read();
	for(int i=1,u,v;i<n;++i){u=read();v=read();add(u,v);add(v,u);}
	dfs0(1);pre();sum=mx[rt=0]=n;csize(1,-1);csize(rt,-1);dfs(rt);
	m=read()+1;
	while(--m){
		op=read();x=read();
		if(op==1){
			if(!cnt){printf("-1\n");continue;}
			if(col[x]){printf("0\n");continue;}
			int ans=1e9+7;if(S2[x].size())ans=min(ans,S2[x].top());
			for(int y=x;fa[y];y=fa[y])if(S2[fa[y]].size()){
				int tmp=S2[fa[y]].top(),tt;
				if(S1[y].size()&&tmp==S1[y].top()){
					if(S2[fa[y]].size()==1)continue;
					tt=tmp;S2[fa[y]].Q1.pop();tmp=S2[fa[y]].top();
				}
				else tt=-1;
				ans=min(ans,tmp+dist(x,fa[y]));
				if(tt!=-1)S2[fa[y]].push(tt);
			}
			printf("%d\n",ans>1e9?-1:ans);
		}
		else{
			col[x]^=1;if(col[x])++cnt;else --cnt;
			if(col[x])S2[x].push(0);else S2[x].remove(0);
			for(int y=x;fa[y];y=fa[y]){
				int t=fa[y];
				if(S1[y].size())S2[t].remove(S1[y].top());
				if(col[x])S1[y].push(dist(x,t));
				else S1[y].remove(dist(x,t));
				if(S1[y].size())S2[t].push(S1[y].top());
			}
		}
	}
	return 0;
}
点击查看N题代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5;
int n,dis[N][2],cnt[N],num,ans;
int head[N],nxt[N<<1],ver[N<<1],val[N<<1],etot;
int sum,rt,sz[N],mx[N],vis[N];
void add(int u,int v,int w){
	ver[++etot]=v;val[etot]=w;
	nxt[etot]=head[u];head[u]=etot;
}
void csize(int u,int fa){
	sz[u]=1;mx[u]=0;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])
			csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
	mx[u]=max(mx[u],sum-sz[u]);
	if(mx[u]<mx[rt])rt=u;
}
struct node{int x,y;}a[N];
bool operator <(const node&a,const node&b){
	return a.x==b.x?a.y<b.y:a.x<b.x;
}
void calc(int u,int fa,int d){
	a[++num]=(node){dis[u][0]-d,dis[u][1]};
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])calc(v,u,d+val[i]);
}
void upd(int u,int fa,int d,int l,int r,int opt){
	cnt[u]+=opt*(r+1-(upper_bound(a+l,a+r+1,(node){d,u})-a));
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])upd(v,u,d+val[i],l,r,opt);
}
void dfs(int u){
	vis[u]=1;a[num=1]={dis[u][0],dis[u][1]};
	for(int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]]){
			int n0=num+1;calc(v,u,val[i]);
			sort(a+n0,a+num+1);upd(v,u,val[i],n0,num,-1);
		}
	sort(a+1,a+num+1);upd(u,-1,0,1,num,1);
	for(int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]]){
			sum=sz[v];rt=0;
			csize(v,-1);csize(rt,-1);dfs(rt);
		}
}
priority_queue<pair<int,int> > Q;
int main(){
	mx[rt]=1e9+7;
	while(scanf("%d",&n)!=EOF){
		ans=etot=rt=0;sum=n;
		for(int i=1;i<=n;i++)
			vis[i]=head[i]=cnt[i]=dis[i][1]=0,dis[i][0]=1e9+7;
		for(int i=1,u,v,w;i<n;i++){
			scanf("%d%d%d",&u,&v,&w);
			add(u,v,w);add(v,u,w);
		}
		for(int i=1,x;i<=n;i++){
			scanf("%d",&x);
			if(x)Q.push(make_pair(0,-i)),dis[i][0]=0,dis[i][1]=i;
		}
		while(!Q.empty()){
			int u=-Q.top().second,d=-Q.top().first;Q.pop();
			if(d!=dis[u][0])continue;
			for(int i=head[u],v;i;i=nxt[i])
				if(dis[v=ver[i]][0]>dis[u][0]+val[i]||
				   dis[v][1]==dis[u][0]+val[i]&&dis[v][1]>dis[u][1]){
					dis[v][0]=dis[u][0]+val[i];
					dis[v][1]=dis[u][1];
					Q.push(make_pair(-dis[v][0],-v));
				}
		}
		csize(1,-1);csize(rt,-1);dfs(rt);
		for(int i=1;i<=n;i++)if(dis[i][0])ans=max(ans,cnt[i]);
		printf("%d\n",ans);
	}
	return 0;
}
点击查看O题代码
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=3e4+5;
int n,m,a[N],fa[N];ll s1[N][15][2],s2[N][15][2],cnt[N][15][2],ans;
int head[N],nxt[N<<1],ver[N<<1],val[N<<1],etot;
inline int read(){
	int x=0;char ch=getchar();
	while(ch<'0'||ch>'9')ch=getchar();
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x;
}
inline void write(ll x){
	if(x>9)write(x/10);
	putchar(x%10+'0');
}
inline void add(int u,int v,int w){
	ver[++etot]=v;val[etot]=w;
	nxt[etot]=head[u];head[u]=etot;
}
int sum,rt,sz[N],mx[N],vis[N],dep[N],dis[N],anc[N][15];
inline void dfs0(int u){
	for(register int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=anc[u][0]){
			dep[v]=dep[u]+1;dis[v]=dis[u]+val[i];
			anc[v][0]=u;dfs0(v);
		}
}
inline void pre(){
	for(register int j=1;(1<<j)<=n;++j)
		for(register int i=1;i<=n;++i)
			anc[i][j]=anc[anc[i][j-1]][j-1];
}
inline int dist(int u,int v){
	int g,x=u,y=v;
	if(dep[u]<dep[v])swap(u,v);
	int d=dep[u]-dep[v];
	for(register int i=13;i>=0;--i)if((d>>i)&1)u=anc[u][i];
	if(u==v)g=u;
	else{
		for(register int i=13;i>=0;--i)
			if(anc[u][i]!=anc[v][i])
				u=anc[u][i],v=anc[v][i];
		g=anc[u][0];
	}
	return dis[x]+dis[y]-2*dis[g];
}
inline void csize(int u,int fa){
	sz[u]=1;mx[u]=0;
	for(register int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa&&!vis[v])
			csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
	mx[u]=max(mx[u],sum-sz[u]);
	if(mx[u]<mx[rt])rt=u;
}
inline void dfs(int u){
	vis[u]=1;
	for(register int i=head[u],v;i;i=nxt[i])
		if(!vis[v=ver[i]]){
			sum=sz[v];rt=0;
			csize(v,-1);csize(rt,-1);
			fa[rt]=u;dfs(rt);
		}
}
inline void modify(int x,int v){
	for(register int i=0;i<14;++i){
		int t1=(a[x]>>i)&1,t2=(v>>i)&1;
		if(t1==t2)continue;
		ll res=0;
		res=res+(s1[x][i][t2^1]-s1[x][i][t1^1]);
		for(register int u=x,d;fa[u];u=fa[u]){
			res=res-(s1[fa[u]][i][t1^1]-s2[u][i][t1^1])
				   +(s1[fa[u]][i][t2^1]-s2[u][i][t2^1])
				   -(cnt[fa[u]][i][t1^1]-cnt[u][i][t1^1])*(d=dist(x,fa[u]))
			       +(cnt[fa[u]][i][t2^1]-cnt[u][i][t2^1])*d;
		}
		ans+=res*(1<<i);
		for(register int u=x;u;u=fa[u]){
			int d1,d2;
			--cnt[u][i][t1];++cnt[u][i][t2];
			s1[u][i][t1]-=(d1=dist(u,x));s1[u][i][t2]+=d1;
			if(fa[u])s2[u][i][t1]-=(d2=dist(fa[u],x)),s2[u][i][t2]+=d2;
		}
	}a[x]=v;
}
int main(){
	mx[rt]=1e9+7;
	while(scanf("%d",&n)!=EOF){
		ans=etot=0;
		for(register int i=1;i<=n;++i){
			vis[i]=head[i]=0;
			for(register int j=0;j<14;++j){
				cnt[i][j][0]=cnt[i][j][1]=0;
				s1[i][j][0]=s1[i][j][1]=0;
				s2[i][j][0]=s2[i][j][1]=0;
			}
		}
		for(register int i=1;i<=n;++i)a[i]=read();
		for(register int i=1,u,v,w;i<n;++i){
			u=read();v=read();w=read();
			add(u,v,w);add(v,u,w);
		}
		dfs0(1);pre();rt=0;sum=n;
		csize(1,-1);csize(rt,-1);fa[rt]=0;dfs(rt);
		for(register int x=1;x<=n;++x){
			for(register int i=0;i<14;++i){
				int t=(a[x]>>i)&1;
				ll res=s1[x][i][t^1];
				for(register int u=x;fa[u];u=fa[u]){
					res=res+(s1[fa[u]][i][t^1]-s2[u][i][t^1])
					       +(cnt[fa[u]][i][t^1]-cnt[u][i][t^1])*dist(x,fa[u]);
				}
				ans+=res*(1<<i);
				for(register int u=x,d1,d2;u;u=fa[u]){
					++cnt[u][i][t];s1[u][i][t]+=dist(u,x);
					if(fa[u])s2[u][i][t]+=dist(fa[u],x);
				}
			}
		}
		m=read();
		for(register int i=1,x,v;i<=n;++i){
			x=read();v=read();modify(x,v);
			write(ans);putchar('\n');
		}
	}
	return 0;
}
posted @ 2023-05-11 22:13  by_chance  阅读(24)  评论(0编辑  收藏  举报