51Nod1709 复杂度分析

51Nod1709 复杂度分析

题目链接

思路

一道很有意思的题。

我们考虑按位累计贡献。记录\(dp[i][j]\)为上一步倍增跳了\(2^j\)步,跳到了\(i\)点的所有状态"1"的个数和。

可以理解为,对于一个儿子u和他的祖先v的深度差,可以表示为一个二进制数。我们倍增的将u跳到v,每次跳的长度一定比上次小,且一定是\(2^k\)

对于一个点i,上一步跳了\(2^j\)步,那么就\(\forall k,k<j\)dp[fa][k]+=dp[i][j]+num[i][j]。其中fa代表的是i\(2^k\)个祖先,num代表的是某个状态有多少种方案。那么这句话的意思就是,对于所有可以由\((i,j)\)这个状态转移到的状态,都把他们的dp加上\((原状态dp值+原状态方案数那么多个“1”)\)。同时num[fa][k]+=num[i][k]ans+=dp[fa][k]*(size[fa]-size[from]-1)from指的是这一次是从fa的哪个亲儿子跳上来的。

那么连初始化都不需要了,直接多考虑一种情况跳跃即可。

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define maxn 105000
#define ll long long
using namespace std;
int dp[maxn][23],f[maxn][23],num[maxn][23],size[maxn],dep[maxn],mem[maxn][23];
ll ans;
struct gg{
	int u,v,next;	
}side[maxn*2];
int head[maxn],cnt,n;
void insert(int u,int v){
	struct gg add={u,v,head[u]};side[++cnt]=add;head[u]=cnt;
}
void dfs1(int now,int fa){
	f[now][0]=fa;size[now]=1;dep[now]=dep[fa]+1;
	for(int i=1;i<=22;i++)f[now][i]=f[f[now][i-1]][i-1];
	for(int i=head[now];i;i=side[i].next){
		int v=side[i].v;if(v==fa)continue;
		dfs1(v,now);size[now]+=size[v];	
	}
}
int jump(int tar,int now){
	for(int i=22;i>=0;i--)if(dep[f[now][i]]>dep[tar])now=f[now][i];
	return now;
}
void init(){
	for(int i=1;i<=n;i++){
		for(int j=0;j<=22;j++){
			int tar=f[i][j];if(!tar)continue;
			mem[i][j]=jump(tar,i);
		}
	}
}
void dfs2(int now,int fa){
	for(int i=head[now];i;i=side[i].next){
		int v=side[i].v;if(v==fa)continue;
		dfs2(v,now);	
	}
	for(int i=1;i<=22;i++){//上次跳了2^i步 
		for(int j=0;j<i;j++){//这次跳了2^j步 
			int tar=f[now][j];//tar为这次跳跃的目的地
			if(!tar)continue; 
			dp[tar][j]+=dp[now][i]+num[now][i];
			num[tar][j]+=num[now][i];
			ans+=(dp[now][i]+num[now][i])*(size[tar]-size[mem[now][j]]);
			
		}
	}
	for(int j=0;j<=22;j++){//这次跳了2^j步 
			int tar=f[now][j];//tar为这次跳跃的目的地 
			if(!tar)continue; 
			dp[tar][j]+=1;
			num[tar][j]+=1;
			ans+=(size[tar]-size[mem[now][j]]);
			
	}
}
inline char nc(){
    static char buf[100000],*p1=buf,*p2=buf;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int _read(){
    char ch=nc();int sum=0;
    while(!(ch>='0'&&ch<='9'))ch=nc();
    while(ch>='0'&&ch<='9')sum=sum*10+ch-48,ch=nc();
    return sum;
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<n;i++){
		int u,v;u=_read();v=_read();insert(u,v);insert(v,u);	
	}
	dfs1(1,0);
	init();
	dfs2(1,0);
	cout<<ans;
}

posted @ 2019-09-12 21:20  GavinZheng  阅读(148)  评论(0编辑  收藏  举报