CF809E Surprise me!

【题意】

 

 【分析】

看到这个$\phi(a_i*a_j)$的格式,可以把它转换成$\phi(x*y)=\frac{\phi(x)*\phi(y)*gcd(x,y)}{\phi(gcd(x,y))}$

先不考虑前的系数,我们的所求就可以转换为$$\sum_{i=1}^{n}\sum_{j=1}^{n}\frac{\phi(a[i])*\phi(a[j])*gcd(a[i],a[j])}{\phi(gcd(a[i],a[j]))}*dist(i,j)$$

然后就是枚举一下gcd $$\sum_{d=1}^{n}\frac{d}{\phi(d)}\sum_{i=1}^{n}\sum_{j=1}^{n}\phi(a[i])*\phi(a[j])*[gcd(a[i],a[j])=d]*dist(i,j)$$

设$f(d)=\sum_{i=1}^{n}\sum_{j=1}^{n}\phi(a[i])*\phi(a[j])*[gcd(a[i],a[j])=d]*dist(i,j)$

$F(d)=\sum_{i=1}^{n}\sum_{j=1}^{n}\phi(a[i])*\phi(a[j])*[d|gcd(a[i],a[j])]*dist(i,j)$

得到$F(x)=\sum_{x|d}f(d) \Rightarrow f(x)=\sum_{x|d}\mu(\frac{d}{x})F(d)$

$ans=\sum_{d=1}^{n}\frac{d}{\phi(d)}f(x) \Rightarrow ans=\sum_{d=1}^{n}\frac{d}{\phi(d)}\sum_{d|k}\mu(\frac{k}{d})F(k)$

我们可以把除了$F(x)$以外的项拿出来,先预处理一下,方便后面的计算

$tmp[d]=\frac{d}{\phi(d)}\sum_{d|k}\mu(\frac{k}{d})$

然后就是求$F(x)$了,把x的倍数的点全部拿出来建虚树,然后做一下树形dp计算即可

树形dp来计算整个虚树内的$\phi(a)*\phi(b)*dis(a,b)$,考虑每条边的贡献是$ne[i].v*(sumtot-sumOfu)*sumOfu$

这里的sumtot是整个树的$\phi$,sumOfu是子树内的$\sum_{v\in u}\phi(v)$

 

 

代码实现起来细节较多,比如虚树的清空问题,还有虚树注意只有真实需要的点有点权,那些lca没有点权,还有各种取模

【代码】

 

#include<bits/stdc++.h>
using namespace std;
const int maxn=4e5+5;
typedef long long ll;
const ll mod=1e9+7;
int phi[maxn],mu[maxn],p[maxn],np[maxn],cntp;
int head[maxn],a[maxn],tot,n,rv[maxn],s[maxn],top,point[maxn],num;
ll invphi[maxn];
struct edge
{
    int to,nxt;
}e[maxn<<1];
void init()
{
    phi[1]=mu[1]=1;
    for(int i=2;i<=n;i++)
    {
        if(!np[i])
        {
            p[++cntp]=i;
            mu[i]=-1; phi[i]=i-1;
        }
        for(int j=1;p[j]*i<=n && j<=cntp;j++)
        {
            np[i*p[j]]=1;
            if(i%p[j]==0)
            {
                phi[i*p[j]]=phi[i]*p[j];
                mu[i*p[j]]=0;
                break;
            }
            else
            {
                phi[i*p[j]]=phi[i]*(p[j]-1);
                mu[i*p[j]]=-mu[i];
            }
        }
    }
}
void add(int x,int y)
{
    e[++tot].to=y; e[tot].nxt=head[x]; head[x]=tot;
}
int dep[maxn],st[maxn],ed[maxn],dfstime,euler[maxn<<1],mn[maxn<<1][30],lg[maxn<<1];
void dfs(int u,int fa)
{
    dep[u]=dep[fa]+1; euler[++dfstime]=u; st[u]=dfstime;
    for(int i=head[u];i;i=e[i].nxt)
    {
        int to=e[i].to;
        if(to==fa) continue;
        dfs(to,u);
        euler[++dfstime]=u;
    }
    ed[u]=dfstime;
}
void lca_init()
{
    lg[0]=-1;
    for(int i=1;i<=dfstime;i++) lg[i]=lg[i>>1]+1;
    for(int i=1;i<=dfstime;i++) mn[i][0]=euler[i];
    for(int j=1;(1<<j)<=dfstime;j++)
        for(int i=1;i+(1<<j)-1<=dfstime;i++)
        {
            int k=i+(1<<(j-1));
            if(dep[mn[i][j-1]]<dep[mn[k][j-1]])
                mn[i][j]=mn[i][j-1];
            else mn[i][j]=mn[k][j-1];
        }
}
int getlca(int x,int y)
{
    int l=st[x],r=ed[y];
    if(l>r) l=st[y],r=ed[x];
    int i=lg[r-l+1],t=r-(1<<i)+1;
    return dep[mn[l][i]]<dep[mn[t][i]]?mn[l][i]:mn[t][i];
}
int calcdis(int x,int y)
{
    return dep[x]+dep[y]-dep[getlca(x,y)]*2;
}
ll qpow(ll a,ll b)
{
    ll res=1;
    while(b)
    {
        if(b&1) res=(res*a)%mod;
        b>>=1;
        a=(a*a)%mod;
    }
    return res;
}
ll tmp[maxn];
bool cmp(int a,int b)
{
    return st[a]<st[b];
}
int h[maxn],ecnt;
struct Edge
{
    int to,nxt,v;
}ne[maxn<<1];
void addedge(int x,int y,int z)
{
    ne[++ecnt].to=y; ne[ecnt].nxt=h[x]; ne[ecnt].v=z; h[x]=ecnt;
}
void build()
{
    ecnt=0;
    sort(point+1,point+num+1,cmp);
    s[top=1]=point[1];
    for(int i=2;i<=num;i++)
    {
        int z=getlca(s[top],point[i]);
        while(dep[s[top-1]]>dep[z])
        {
            addedge(s[top-1],s[top],calcdis(s[top-1],s[top]));
            top--;
        }
        if(s[top]!=z)
        {
            addedge(z,s[top],calcdis(s[top],z));
            if(s[top-1]==z) top--;
            else s[top]=z;
        }
        s[++top]=point[i];
    }
    while(--top) addedge(s[top],s[top+1],calcdis(s[top],s[top+1]));
}
ll val[maxn],ress,sumtot;
ll dfsdp(int u,int fa)
{
    ll sumphi=val[u];
    for(int i=h[u];i;i=ne[i].nxt)
    {
        int to=ne[i].to;
        if(to==fa) continue;
        ll temp=dfsdp(to,u);
        sumphi=(sumphi+temp)%mod;
        ress=(ress+(ne[i].v*1LL*temp%mod)*(sumtot-temp))%mod;
    }    
    h[u]=0; val[u]=0;
    return sumphi;
}
int main()
{
    freopen("a.in","r",stdin);
    freopen("a.out","w",stdout);
    scanf("%d",&n);
    init();
    for(int i=1;i<=n;i++) scanf("%d",&a[i]),rv[a[i]]=i;
    int x,y;
    for(int i=1;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        add(x,y); add(y,x);
    }
    dfs(1,0);
    lca_init();
    for(int i=1;i<=n;i++) invphi[i]=qpow(phi[i],mod-2);
    for(int d=1;d<=n;d++)
        for(int i=d;i<=n;i+=d)
            tmp[i]=((tmp[i]+(1LL*d*invphi[d]%mod)*mu[i/d]%mod)+mod)%mod;
    ll ans=0;
    for(int d=1;d<=n;d++)
    {
        num=0; sumtot=0;
        for(int i=d;i<=n;i+=d) point[++num]=rv[i],val[rv[i]]=phi[i],sumtot=(sumtot+phi[i])%mod;
        build();ress=0;
        dfsdp(s[1],0);
        ress=ress*2%mod;
        ans=(ans+(ress*tmp[d]%mod))%mod;
        if(ans<0) ans+=mod;
    }
    ans=ans*qpow(n,mod-2)%mod*qpow(n-1,mod-2)%mod;
    if(ans<0) ans+=mod;
    printf("%lld\n",ans);
    return 0;
}

 

posted @ 2021-05-21 08:48  andyc_03  阅读(64)  评论(0)    收藏  举报