【XSY1602】安全网络 树形DP 数学

题目大意

  有一颗树,要为每个节点赋一个值\(l_i\leq a_i\leq r_i\),使得任意相邻的节点互素。然后对每个节点统计\(a_i\)在所有可能的情况中的和。

  \(n\leq 50,1\leq l_i\leq r_i\leq m,m=50000\)

题解

  设\(f_{i,j}\)为以\(i\)为根的子树都赋了值后\(a_i=j\)的方案数。那么怎么处理\(f_v\)\(f_i\)的贡献呢?

\[g_i=\mu(i)\sum_{i|j}f_{v,j}\\ f_{x,i}\times=\sum_{j|i}g_j \]

  \(f_{v,i}\)\(f_{x,j}\)的贡献是\(\sum_{k|(i,j)}\mu(k)f_{v,i}\)。因为\(\sum_{d|n}\mu(d)=[n=1]\),所以只有\(\gcd(i,j)=1\)\(f_{v,i}\)\(f_{x,j}\)有贡献。

  设\(h_{i,j}\)为整棵树都赋了值后\(a_i=j\)的方案数。我们发现,\(h_v\)是把\(h_x\)减去\(f_v\)后再加到\(f_v\)上。用逆元搞一搞即可。

  然后就是愉快的卡常时间了。

  时间复杂度:\(O(nm\log m)\)

代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<list>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
ll p=1000000007;
int m=50000;
ll fp(ll a,ll b)
{
	ll s=1;
	while(b)
	{
		if(b&1)
			s=s*a%p;
		a=a*a%p;
		b>>=1;
	}
	return s;
}
ll exgcd(ll a,ll b,ll &x,ll &y)
{
	if(!b)
	{
		x=1;
		y=0;
		return a;
	}
	ll ab=a/b;
	ll c=a-b*ab;
	ll d=exgcd(b,c,y,x);
	y-=x*ab;
	return d;
}
namespace prime
{
	int cnt;
	int b[100010];
	int p[100010];
	int u[100010];
	void init()
	{
		cnt=0;
		memset(b,0,sizeof b);
		int i,j;
		u[1]=1;
		for(i=2;i<=m;i++)
		{
			if(!b[i])
			{
				p[++cnt]=i;
				u[i]=-1;
			}
			for(j=1;j<=cnt&&i*p[j]<=m;j++)
			{
				b[i*p[j]]=1;
				if(i%p[j]==0)
				{
					u[i*p[j]]=0;
					break;
				}
				u[i*p[j]]=-u[i];
			}
		}
	}
};
list<int> li[60];
int l[60];
int r[60];
ll ans[60];
ll f[60][50010];
ll g[60][50010];
ll c[50010];
ll d[50010];
ll e[50010];
void p0(int x)
{
	while(x--)
		printf(" 0");
	printf("\n");
}
void calc()
{
	memset(d,0,sizeof d);
	memset(e,0,sizeof e);
	int i,j;
	for(i=1;i<=m;i++)
		if(prime::u[i])
		{
			for(j=i;j<=m;j+=i)
				e[i]+=c[j];
			e[i]%=p;
			e[i]*=prime::u[i];
		}
	for(i=1;i<=m;i++)
		if(e[i])
			for(j=i;j<=m;j+=i)
				d[j]+=e[i];
	for(i=1;i<=m;i++)
		d[i]%=p;
}
void dfs1(int x,int fa)
{
	int i;
	for(i=l[x];i<=r[x];i++)
		f[x][i]=1;
	for(auto v:li[x])
		if(v!=fa)
		{
			dfs1(v,x);
			memcpy(c,f[v],sizeof f[v]);
			calc();
			for(i=l[x];i<=r[x];i++)
			{
				g[v][i]=d[i];
				f[x][i]=f[x][i]*d[i]%p;
			}
		}
}
void dfs2(int x,int fa)
{
	int i;
	if(fa)
	{
//		memcpy(c,f[x],sizeof f[x]);
//		calc();
		ll v1,v2;
		for(i=1;i<=m;i++)
//			if(f[fa][i]&&d[i])
//			{
////				c[i]=f[fa][i]*fp(d[i],p-2)%p;
//				int gcd=exgcd(d[i],p,v1,v2);
//				if(gcd==-1)
//					v1=-v1;
//				c[i]=f[fa][i]*v1%p;
//			}
			if(f[fa][i]&&g[x][i])
			{
				int gcd=exgcd(g[x][i],p,v1,v2);
				if(gcd==-1)
					v1=-v1;
				c[i]=f[fa][i]*v1%p;
			}
			else
				c[i]=0;
		calc();
		for(i=l[x];i<=r[x];i++)
			f[x][i]=f[x][i]*d[i]%p;
	}
	for(i=l[x];i<=r[x];i++)
	{
		f[x][i]=(f[x][i]%p+p)%p;
		ans[x]=(ans[x]+f[x][i]*i%p)%p;
	}
	for(auto v:li[x])
		if(v!=fa)
			dfs2(v,x);
}
void solve()
{
	int n;
	scanf("%d",&n);
	int i;
	for(i=1;i<=n;i++)
		scanf("%d",&l[i]);
	for(i=1;i<=n;i++)
		scanf("%d",&r[i]);
	for(i=1;i<=n;i++)
		li[i].clear();
	int x,y;
	for(i=1;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		li[x].push_back(y);
		li[y].push_back(x);
	}
	for(i=1;i<=n;i++)
		if(l[i]>r[i])
		{
			p0(n);
			return;
		}
	memset(g,0,sizeof g);
	memset(f,0,sizeof f);
	memset(ans,0,sizeof ans);
	dfs1(1,0);
	dfs2(1,0);
	for(i=1;i<=n;i++)
		printf("%lld ",ans[i]);
	printf("\n");
}
int main()
{
	freopen("b.in","r",stdin);
	freopen("b.out","w",stdout);
	int t;
	prime::init();
	scanf("%d",&t);
	while(t--)
		solve();
	return 0;
}
posted @ 2018-03-05 21:01  ywwyww  阅读(216)  评论(0编辑  收藏  举报