【CF434E】Furukawa Nagisa's Tree 点分治

【CF434E】Furukawa Nagisa's Tree

题意:一棵n个点的树,点有点权。定义$G(a,b)$表示:我们将树上从a走到b经过的点都拿出来,设这些点的点权分别为$z_0,z_1...z_{l-1}$,则$G(a,b)=z_0+z_1k^1+z_2k^2+...+z_{l-1}k^{l-1}$。如果$G(a,b)=X \mod Y$(保证Y是质数),则我们称(a,b)是好的,否则是坏的。现在想知道,有多少个三元组(a,b,c),满足(a,b),(b,c),(a,c)都是好的或者都是坏的?

$n\le 10^5,Y\le 10^9$

题解:由于一个点对要么是好的要么是坏的,所以我们可以枚举一下所有符合条件的3元组的情况。不过符合条件需要3条边都相同,那我们可以反过来,统计不合法的3元组的情况(一共$2^3-2$种情况)。经过观察我们发现,我们可以在 同时连接两种颜色的边 的那个点处统计贡献,即把三元组的贡献放到了点上。我们设$in_0(),in_1(i),out_0(i),out_1(i)$表示i有多少个好(坏)边连入(出),则一个点对答案的贡献就变成:

$2in_0(i)in_1(i)+2out_0(i)out_1(i)+in_0(i)out_1(i)+in_1(i)out_0(i)$

最后将答案/2即可。

所以现在我们只需要求:对于每个点,有多少好边连入(连出)。这个用点分治可以搞定,因为我们容易计算两个多项式连接起来的结果。本题我采用的是容斥式的点分治。

 

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn=100010;
typedef long long ll;
int n,cnt,tot,mn,rt;
ll X,Y,K,Ki,ans;
ll pw[maxn],pi[maxn],v[maxn],in1[maxn],in0[maxn],out1[maxn],out0[maxn];
int to[maxn<<1],nxt[maxn<<1],head[maxn],vis[maxn],siz[maxn];
struct node
{
	ll x;
	int y;
	node() {}
	node(ll a,int b) {x=a,y=b;}
	bool operator < (const node &a) const {return x<a.x;}
}p[maxn],q[maxn];
inline int rd()
{
	char gc=getchar();	int ret=0;
	while(gc<'0'||gc>'9')	gc=getchar();
	while(gc>='0'&&gc<='9')	ret=ret*10+gc-'0',gc=getchar();
	return ret;
}
inline void add(int a,int b)
{
	to[cnt]=b,nxt[cnt]=head[a],head[a]=cnt++;
}
inline ll pm(ll x,ll y)
{
	ll z=1;
	while(y)
	{
		if(y&1)	z=z*x%Y;
		x=x*x%Y,y>>=1;
	}
	return z;
}
void getrt(int x,int fa)
{
	int i,tmp=0;
	siz[x]=1;
	for(i=head[x];i!=-1;i=nxt[i])	if(!vis[to[i]]&&to[i]!=fa)	getrt(to[i],x),siz[x]+=siz[to[i]],tmp=max(tmp,siz[to[i]]);
	tmp=max(tmp,n-siz[x]);
	if(tmp<mn)	mn=tmp,rt=x;
}
void getp(int x,int fa,int dep,ll s1,ll s2)
{
	s1=(s1*K+v[x])%Y,s2=(s2+v[x]*((!dep)?0:pw[dep-1]))%Y,dep++;
	p[++tot]=node((X-s1+Y)*pi[dep]%Y,x),q[tot]=node(s2,x);
	for(int i=head[x];i!=-1;i=nxt[i])	if(!vis[to[i]]&&to[i]!=fa)
		getp(to[i],x,dep,s1,s2);
}
void calc(int x,int flag,int dep,ll s1,ll s2)
{
	int i,j,cnt;
	tot=0;
	s1=(s1*K+v[x])%Y,s2=(s2+v[x]*((!dep)?0:pw[dep-1]))%Y,dep++;
	p[++tot]=node((X-s1+Y)*pi[dep]%Y,x),q[tot]=node(s2,x);
	for(i=head[x];i!=-1;i=nxt[i])	if(!vis[to[i]])	getp(to[i],x,dep,s1,s2);
	sort(p+1,p+tot+1),sort(q+1,q+tot+1);
	for(cnt=0,i=j=1;i<=tot;i++)
	{
		for(;j<=tot&&q[j].x<=p[i].x;j++)
		{
			if(j==1||q[j].x!=q[j-1].x)	cnt=0;
			cnt++;
		}
		if(j!=1&&q[j-1].x==p[i].x)	out1[p[i].y]+=cnt*flag;
	}
	for(cnt=0,i=j=1;i<=tot;i++)
	{
		for(;j<=tot&&p[j].x<=q[i].x;j++)
		{
			if(j==1||p[j].x!=p[j-1].x)	cnt=0;
			cnt++;
		}
		if(j!=1&&p[j-1].x==q[i].x)	in1[q[i].y]+=cnt*flag;
	}
}
void dfs(int x)
{
	vis[x]=1;
	int i;
	calc(x,1,0,0,0);
	for(i=head[x];i!=-1;i=nxt[i])	if(!vis[to[i]])
	{
		calc(to[i],-1,1,v[x],0);
		tot=siz[to[i]],mn=1<<30,getrt(to[i],x),dfs(rt);
	}
}
int main()
{
	//freopen("cf434E.in","r",stdin);
	
	n=rd(),Y=rd(),K=rd(),X=rd(),Ki=pm(K,Y-2);
	int i,a,b;
	memset(head,-1,sizeof(head));
	for(i=1;i<=n;i++)	v[i]=rd();
	for(i=pw[0]=pi[0]=1;i<=n;i++)	pw[i]=pw[i-1]*K%Y,pi[i]=pi[i-1]*Ki%Y;
	for(i=1;i<n;i++)	a=rd(),b=rd(),add(a,b),add(b,a);
	tot=n,mn=1<<30,getrt(1,0),dfs(rt);
	for(i=1;i<=n;i++)
	{
		in0[i]=n-in1[i],out0[i]=n-out1[i];
		ans+=2*in1[i]*in0[i]+2*out1[i]*out0[i]+in0[i]*out1[i]+in1[i]*out0[i];
	}
	printf("%lld",1ll*n*n*n-ans/2);
	return 0;
}

 

posted @ 2018-04-05 17:08 CQzhangyu 阅读(...) 评论(...) 编辑 收藏