Codeforces 809E - Surprise me!(虚树+莫比乌斯反演)

Codeforces 题目传送门 & 洛谷题目传送门

1A,就 nm 爽(

首先此题一个很棘手的地方在于贡献的计算式中涉及 \(\varphi(a_ia_j)\),而这东西与 \(i,j\) 都有关,无法拆开来计算,因此无法独立考虑 \(i,j\) 的贡献。因此我们要想方设法把这里面的 \(a_ia_j\) 拆开来,我们考虑探究 \(\varphi(a_ia_j)\)\(\varphi(a_i),\varphi(a_j)\) 有什么关系,很容易发现一个性质,那就是 \(\varphi(a_ia_j)=\dfrac{\varphi(a_i)\varphi(a_j)\text{gcd}(a_i,a_j)}{\varphi(\text{gcd}(a_i,a_j))}\)(提示:考虑重复质因子的贡献),故 \(ans=\sum\limits_{i=1}^n\sum\limits_{j=1}^n\dfrac{\varphi(a_i)\varphi(a_j)\text{gcd}(a_i,a_j)}{\varphi(\text{gcd}(a_i,a_j))}\text{dist}(i,j)\)。这里涉及 \(\text{gcd}\),可以套路地想到莫比乌斯反演,我们考虑枚举 \(\text{gcd}(a_i,a_j)\),那么 \(ans=\sum\limits_{d=1}^n\dfrac{d}{\varphi(d)}\sum\limits_{i=1}^n\sum\limits_{j=1}^n\varphi(a_i)\varphi(a_j)\text{dist}(i,j)[\text{gcd}(a_i,a_j)=d]\),按照莫比乌斯反演的套路我们设 \(f(d)=\varphi(a_i)\varphi(a_j)\text{dist}(i,j)[\text{gcd}(a_i,a_j)=d]\),再设 \(g(d)=\varphi(a_i)\varphi(a_j)\text{dist}(i,j)[d\mid\text{gcd}(a_i,a_j)]\),那么显然有 \(g(d)=\sum\limits_{d\mid n}f(n)\),反演以下可得 \(f(d)=\sum\limits_{d\mid n}g(n)\mu(\dfrac{n}{d})\),因此我们只需求出 \(g(n)\) 就可求出 \(f(d)\),也就顺带着能够求出答案了。

接下来考虑怎样求 \(f(d)\),由于 \(d\mid\gcd(a_i,a_j)\),必然也有 \(d\mid a_i,d\mid a_j\),我们考虑设 \(S=\{i|d\mid a_i\}\),那么显然等价于求 \(\sum\limits_{x,y\in S}\varphi(a_x)\varphi(a_y)\text{dist}(x,y)\),下设 \(b_i=\varphi(a_i),z=\text{lca}(x,y)\)\(d_i\) 表示 \(i\) 的深度,那么 \(f(d)=\sum\limits_{x,y\in S}b_xb_y(d_x+d_y-2d_z)\),把括号打开即可得到 \(\sum\limits_{x,y\in S}b_xb_yd_x+\sum\limits_{x,y\in S}b_xb_yd_y-2\sum\limits_{x,y\in S}b_xb_yd_z\),左边两项显然是相等的,我们只需算出 \(s_1=\sum\limits_{x\in S}b_xd_x,s_2=\sum\limits_{x\in S}b_y\),那么 \(\sum\limits_{x,y\in S}b_xb_yd_x=s_1s_2\),关键在于第三项怎么求,我们考虑对 \(S\) 建立虚树并在 \(z\) 处统计答案,记 \(dp_x\)\(x\) 子树中所有关键点的 \(b\) 值之和,那么我们显然可以像启发式合并/点分治一样在合并子树的 \(dp\) 值的过程中统计答案。

由于 \(a_i\) 是一个 \(1\sim n\) 的排列,因此所有点出现在 \(S\) 中的次数之和为 \(n\ln n\),即 \(\sum|S|=n\ln n\),再加上建立虚树的复杂度 \(n\log n\),故该算法复杂度为 \(n\ln n\log n\)

细节有亿点点多啊……

const int MAXN=2e5;
const int LOG_N=19;
const int MOD=1e9+7;
int qpow(int x,int e=MOD-2){
	int ret=1;
	for(;e;e>>=1,x=1ll*x*x%MOD) if(e&1) ret=1ll*ret*x%MOD;
	return ret;
}
int pr[MAXN/5+5],vis[MAXN+5],phi[MAXN+5],mu[MAXN+5],prcnt=0;
void sieve(int n){
	phi[1]=mu[1]=1;
	for(int i=2;i<=n;i++){
		if(!vis[i]){pr[++prcnt]=i;phi[i]=i-1;mu[i]=-1;}
		for(int j=1;j<=prcnt&&i*pr[j]<=n;j++){
			vis[i*pr[j]]=1;
			if(i%pr[j]==0){phi[i*pr[j]]=phi[i]*pr[j];break;}
			else phi[i*pr[j]]=phi[i]*phi[pr[j]],mu[i*pr[j]]=-mu[i];
		}
	}
}
int n,a[MAXN+5],pos[MAXN+5],f[MAXN+5],res[MAXN+5];
int hd[MAXN+5],to[MAXN*2+5],nxt[MAXN*2+5],ec=0;
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
int seq[MAXN*2+5],dfn[MAXN+5],tim=0,tim_dfn=0,dep[MAXN+5],bgt[MAXN+5];
pii st[MAXN*2+5][LOG_N+2];
void dfs(int x,int f){
	bgt[x]=++tim_dfn;
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==f) continue;
		dep[y]=dep[x]+1;dfs(y,x);seq[++tim]=x;
	} seq[++tim]=x;dfn[x]=tim;
}
pii querymn(int l,int r){
	int k=31-__builtin_clz(r-l+1);
	return min(st[l][k],st[r-(1<<k)+1][k]);
}
int getlca(int x,int y){
	if(dfn[x]>dfn[y]) x^=y^=x^=y;
	return querymn(dfn[x],dfn[y]).se;
}
int pt[MAXN+5],pcnt=0,stk[MAXN+5],tp=0;
bool in[MAXN+5];vector<pii> g[MAXN+5];
void adde_vir(int u,int v){g[u].pb(mp(v,dep[v]-dep[u]));}
void insert(int x){
	if(!tp){stk[++tp]=x;return;}
	int lca=getlca(x,stk[tp]);
	while(tp>=2&&dep[lca]<dep[stk[tp-1]]){adde_vir(stk[tp-1],stk[tp]);tp--;}
	if(tp&&dep[lca]<dep[stk[tp]]) adde_vir(lca,stk[tp--]);
	if(!tp||stk[tp]!=lca) stk[++tp]=lca;
	stk[++tp]=x;
}
void fin(){
	while(tp>=2) adde_vir(stk[tp-1],stk[tp]),tp--;
	tp=0;
}
int ret=0,dp[MAXN+5],dis[MAXN+5];
void dfs_vir(int x){
	dp[x]=0;
	for(int i=0;i<g[x].size();i++){
		int y=g[x][i].fi,z=g[x][i].se;
		dis[y]=dis[x]+z;dfs_vir(y);
		ret=(ret-2ll*dp[x]*dp[y]%MOD*dis[x]%MOD+MOD)%MOD;
		dp[x]=(dp[x]+dp[y])%MOD;
	} if(in[x]){
		ret=(ret-1ll*dp[x]*phi[a[x]]%MOD*dis[x]%MOD+MOD)%MOD;
		dp[x]=(dp[x]+phi[a[x]])%MOD;
		ret=(ret-1ll*dp[x]*phi[a[x]]%MOD*dis[x]%MOD+MOD)%MOD;
	}
}
void clear(int x){
	dp[x]=dis[x]=0;
	for(int i=0;i<g[x].size();i++) clear(g[x][i].fi);
	while(g[x].size()) g[x].pop_back();
}
int solve(int x){
	pcnt=0;
	for(int i=x;i<=n;i+=x) pt[++pcnt]=pos[i],in[pos[i]]=1;
	sort(pt+1,pt+pcnt+1,[](int x,int y){return bgt[x]<bgt[y];});
	if(!in[1]) insert(1);for(int i=1;i<=pcnt;i++) insert(pt[i]);fin();
	ret=0;dfs_vir(1);int s1=0,s2=0;
	for(int i=1;i<=pcnt;i++){
		s1=(s1+1ll*phi[a[pt[i]]]*dis[pt[i]])%MOD;
		s2=(s2+phi[a[pt[i]]])%MOD;
	}
	ret=(ret+1ll*s1*s2)%MOD;ret=ret*2%MOD;clear(1);
	for(int i=x;i<=n;i+=x) in[pos[i]]=0;
	return ret;
}
int main(){
	scanf("%d",&n);sieve(n);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]),pos[a[i]]=i;
	for(int i=1,u,v;i<n;i++) scanf("%d%d",&u,&v),adde(u,v),adde(v,u);
	dfs(1,0);for(int i=1;i<=n<<1;i++) st[i][0]=mp(dep[seq[i]],seq[i]);
	for(int i=1;i<=LOG_N;i++) for(int j=1;j+(1<<i)-1<=n<<1;j++)
		st[j][i]=min(st[j][i-1],st[j+(1<<i-1)][i-1]);
	for(int i=1;i<=n;i++) f[i]=solve(i);
	for(int i=1;i<=n;i++) for(int j=i;j<=n;j+=i){
		if(mu[j/i]==1) res[i]=(res[i]+f[j])%MOD;
		else if(mu[j/i]==-1) res[i]=(res[i]-f[j]+MOD)%MOD;
	} int ans=0;
	for(int i=1;i<=n;i++) ans=(ans+1ll*res[i]*i%MOD*qpow(phi[i]))%MOD;
	ans=1ll*ans*qpow(n)%MOD;ans=1ll*ans*qpow(n-1)%MOD;
	printf("%d\n",ans);
	return 0;
}
posted @ 2021-04-07 20:06  tzc_wk  阅读(57)  评论(0编辑  收藏  举报