CF809E Surprise me!
【题意】

【分析】
看到这个$\phi(a_i*a_j)$的格式,可以把它转换成$\phi(x*y)=\frac{\phi(x)*\phi(y)*gcd(x,y)}{\phi(gcd(x,y))}$
先不考虑前的系数,我们的所求就可以转换为$$\sum_{i=1}^{n}\sum_{j=1}^{n}\frac{\phi(a[i])*\phi(a[j])*gcd(a[i],a[j])}{\phi(gcd(a[i],a[j]))}*dist(i,j)$$
然后就是枚举一下gcd $$\sum_{d=1}^{n}\frac{d}{\phi(d)}\sum_{i=1}^{n}\sum_{j=1}^{n}\phi(a[i])*\phi(a[j])*[gcd(a[i],a[j])=d]*dist(i,j)$$
设$f(d)=\sum_{i=1}^{n}\sum_{j=1}^{n}\phi(a[i])*\phi(a[j])*[gcd(a[i],a[j])=d]*dist(i,j)$
$F(d)=\sum_{i=1}^{n}\sum_{j=1}^{n}\phi(a[i])*\phi(a[j])*[d|gcd(a[i],a[j])]*dist(i,j)$
得到$F(x)=\sum_{x|d}f(d) \Rightarrow f(x)=\sum_{x|d}\mu(\frac{d}{x})F(d)$
$ans=\sum_{d=1}^{n}\frac{d}{\phi(d)}f(x) \Rightarrow ans=\sum_{d=1}^{n}\frac{d}{\phi(d)}\sum_{d|k}\mu(\frac{k}{d})F(k)$
我们可以把除了$F(x)$以外的项拿出来,先预处理一下,方便后面的计算
$tmp[d]=\frac{d}{\phi(d)}\sum_{d|k}\mu(\frac{k}{d})$
然后就是求$F(x)$了,把x的倍数的点全部拿出来建虚树,然后做一下树形dp计算即可
树形dp来计算整个虚树内的$\phi(a)*\phi(b)*dis(a,b)$,考虑每条边的贡献是$ne[i].v*(sumtot-sumOfu)*sumOfu$
这里的sumtot是整个树的$\phi$,sumOfu是子树内的$\sum_{v\in u}\phi(v)$
代码实现起来细节较多,比如虚树的清空问题,还有虚树注意只有真实需要的点有点权,那些lca没有点权,还有各种取模
【代码】
#include<bits/stdc++.h> using namespace std; const int maxn=4e5+5; typedef long long ll; const ll mod=1e9+7; int phi[maxn],mu[maxn],p[maxn],np[maxn],cntp; int head[maxn],a[maxn],tot,n,rv[maxn],s[maxn],top,point[maxn],num; ll invphi[maxn]; struct edge { int to,nxt; }e[maxn<<1]; void init() { phi[1]=mu[1]=1; for(int i=2;i<=n;i++) { if(!np[i]) { p[++cntp]=i; mu[i]=-1; phi[i]=i-1; } for(int j=1;p[j]*i<=n && j<=cntp;j++) { np[i*p[j]]=1; if(i%p[j]==0) { phi[i*p[j]]=phi[i]*p[j]; mu[i*p[j]]=0; break; } else { phi[i*p[j]]=phi[i]*(p[j]-1); mu[i*p[j]]=-mu[i]; } } } } void add(int x,int y) { e[++tot].to=y; e[tot].nxt=head[x]; head[x]=tot; } int dep[maxn],st[maxn],ed[maxn],dfstime,euler[maxn<<1],mn[maxn<<1][30],lg[maxn<<1]; void dfs(int u,int fa) { dep[u]=dep[fa]+1; euler[++dfstime]=u; st[u]=dfstime; for(int i=head[u];i;i=e[i].nxt) { int to=e[i].to; if(to==fa) continue; dfs(to,u); euler[++dfstime]=u; } ed[u]=dfstime; } void lca_init() { lg[0]=-1; for(int i=1;i<=dfstime;i++) lg[i]=lg[i>>1]+1; for(int i=1;i<=dfstime;i++) mn[i][0]=euler[i]; for(int j=1;(1<<j)<=dfstime;j++) for(int i=1;i+(1<<j)-1<=dfstime;i++) { int k=i+(1<<(j-1)); if(dep[mn[i][j-1]]<dep[mn[k][j-1]]) mn[i][j]=mn[i][j-1]; else mn[i][j]=mn[k][j-1]; } } int getlca(int x,int y) { int l=st[x],r=ed[y]; if(l>r) l=st[y],r=ed[x]; int i=lg[r-l+1],t=r-(1<<i)+1; return dep[mn[l][i]]<dep[mn[t][i]]?mn[l][i]:mn[t][i]; } int calcdis(int x,int y) { return dep[x]+dep[y]-dep[getlca(x,y)]*2; } ll qpow(ll a,ll b) { ll res=1; while(b) { if(b&1) res=(res*a)%mod; b>>=1; a=(a*a)%mod; } return res; } ll tmp[maxn]; bool cmp(int a,int b) { return st[a]<st[b]; } int h[maxn],ecnt; struct Edge { int to,nxt,v; }ne[maxn<<1]; void addedge(int x,int y,int z) { ne[++ecnt].to=y; ne[ecnt].nxt=h[x]; ne[ecnt].v=z; h[x]=ecnt; } void build() { ecnt=0; sort(point+1,point+num+1,cmp); s[top=1]=point[1]; for(int i=2;i<=num;i++) { int z=getlca(s[top],point[i]); while(dep[s[top-1]]>dep[z]) { addedge(s[top-1],s[top],calcdis(s[top-1],s[top])); top--; } if(s[top]!=z) { addedge(z,s[top],calcdis(s[top],z)); if(s[top-1]==z) top--; else s[top]=z; } s[++top]=point[i]; } while(--top) addedge(s[top],s[top+1],calcdis(s[top],s[top+1])); } ll val[maxn],ress,sumtot; ll dfsdp(int u,int fa) { ll sumphi=val[u]; for(int i=h[u];i;i=ne[i].nxt) { int to=ne[i].to; if(to==fa) continue; ll temp=dfsdp(to,u); sumphi=(sumphi+temp)%mod; ress=(ress+(ne[i].v*1LL*temp%mod)*(sumtot-temp))%mod; } h[u]=0; val[u]=0; return sumphi; } int main() { freopen("a.in","r",stdin); freopen("a.out","w",stdout); scanf("%d",&n); init(); for(int i=1;i<=n;i++) scanf("%d",&a[i]),rv[a[i]]=i; int x,y; for(int i=1;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } dfs(1,0); lca_init(); for(int i=1;i<=n;i++) invphi[i]=qpow(phi[i],mod-2); for(int d=1;d<=n;d++) for(int i=d;i<=n;i+=d) tmp[i]=((tmp[i]+(1LL*d*invphi[d]%mod)*mu[i/d]%mod)+mod)%mod; ll ans=0; for(int d=1;d<=n;d++) { num=0; sumtot=0; for(int i=d;i<=n;i+=d) point[++num]=rv[i],val[rv[i]]=phi[i],sumtot=(sumtot+phi[i])%mod; build();ress=0; dfsdp(s[1],0); ress=ress*2%mod; ans=(ans+(ress*tmp[d]%mod))%mod; if(ans<0) ans+=mod; } ans=ans*qpow(n,mod-2)%mod*qpow(n-1,mod-2)%mod; if(ans<0) ans+=mod; printf("%lld\n",ans); return 0; }

浙公网安备 33010602011771号