[树形dp] Jzoj P5814 树
题解
- 首先,设d[i]为i的出度,f[i]为从i走向根节点的期望步数,g[i]为从根节点走到i的期望步数
- 那么,f[i]=1/d[i]+Σ(f[i]+f[son]+1)/d[i]
- 其实就是
- ①直接走到父亲的期望1/d[i]
- ②走到儿子,然后走回i,再到父亲的期望
- 那么,g[i]=1/d[fa[i]]+(1+g[i]+g[fa[i]])/d[fa[i]]+Σ(son!=i)(1+g[i]+f[son])/d[fa[i]]
- 其实就是
- ①直接走到i的期望1/d[i]
- ②走到爷爷,然后走回父亲,再走到i的期望
- ③走到非i的儿子,然后走回父亲,再走到i的期望
- 然后将f[i]化简一下:
- f[i]=1/d[i]+Σ(f[i]+f[son]+1)/d[i]
- d[i]*f[i]=1+Σ(f[i]+f[son]+1)
- d[i]*f[i]=1+(d[i]-1)*f[i]+Σf[son]+d[i]-1
- (d[i]-d[i]+1)*f[i]=d[i]+Σf[son]
- f[i]=d[i]+Σf[son]
- 再化简一下g[i]:
- g[i]=1/d[fa[i]]+(1+g[i]+g[fa[i]])/d[fa[i]]+Σ(son!=i)(1+g[i]+f[son])/d[fa[i]]
- d[fa[i]]*g[i]=1+1+g[i]+g[fa[i]]+Σ(son!=i)(1+g[i]+f[son])
- d[fa[i]]*g[i]=2+g[i]+g[fa[i]]+(d[i]-2)+(d[i]-2)*g[i]+Σf[son]
- d[fa[i]]*g[i]=d[i]+(d[i]-1)*g[i]+g[fa[i]]+Σf[son]
- g[i]=d[i]+g[fa[i]]+Σf[son]
- 然后就用dfs预处理出这两个数组
- f[i]=f[i]+f[fa],g[i]=g[i]+g[fa],求到最大的祖先的期望步数
- 对于一组询问u,v,路径就是u到lca,和v到lca
- 那么ans=f[u]-f[lca]+g[v]-g[lca]
代码
1 #include <cstdio> 2 #include <iostream> 3 #include <cstring> 4 #include <cmath> 5 using namespace std; 6 const long long mo=1e9+7; 7 struct edge {int to,from;}e[1000010*2]; 8 int n,q,head[1000010],cnt; 9 long long sum,f[1000010][20],deep[1000010],ans,k1[1000010],k2[1000010]; 10 void insert(int x,int y) {e[++cnt].to=y; e[cnt].from=head[x]; head[x]=cnt;} 11 void dfs(int x,int fa) 12 { 13 f[x][0]=fa; 14 k1[x]=0; 15 for (int i=head[x];i;i=e[i].from) 16 if (e[i].to!=fa) 17 { 18 dfs(e[i].to,x); 19 k1[x]=(k1[x]+k1[e[i].to]+1)%mo; 20 } 21 k1[x]=(k1[x]+1)%mo; 22 } 23 void dfs1(int x,int fa) 24 { 25 long long sum=0; 26 for (int i=head[x];i;i=e[i].from) 27 if (e[i].to!=fa) sum=(sum+k1[e[i].to]+1)%mo; 28 else sum=(sum+k2[x]+1)%mo; 29 for (int i=head[x];i;i=e[i].from) 30 if (e[i].to!=fa) 31 { 32 k2[e[i].to]=((sum-k1[e[i].to])%mo+mo)%mo; 33 dfs1(e[i].to,x); 34 } 35 } 36 void dfs2(int x,int fa) 37 { 38 deep[x]=deep[fa]+1; 39 k1[x]=(k1[x]+k1[fa])%mo,k2[x]=(k2[x]+k2[fa])%mo; 40 for (int i=head[x];i;i=e[i].from) 41 if (e[i].to!=fa) 42 dfs2(e[i].to,x); 43 } 44 int getlca(int u,int w) 45 { 46 if (deep[u]<deep[w]) swap(u,w); 47 int d=deep[u]-deep[w]; 48 if (d) for (int i=0;i<=log(n)/log(2)+1&&d;i++,d>>=1) if (d&1) u=f[u][i]; 49 if (u==w) return u; 50 for (int i=log(n)/log(2)+1;i>=0;i--) 51 if (f[u][i]!=f[w][i]) 52 u=f[u][i],w=f[w][i]; 53 return f[u][0]; 54 } 55 int main() 56 { 57 freopen("tree.in","r",stdin); 58 freopen("tree.out","w",stdout); 59 scanf("%d%d",&n,&q); 60 for (int i=1;i<=n-1;i++) 61 { 62 int u,v; 63 scanf("%d%d",&u,&v); 64 insert(u,v),insert(v,u); 65 } 66 dfs(1,-1); 67 k1[1]=k2[1]=0; 68 dfs1(1,-1); 69 dfs2(1,-1); 70 for (int i=1;i<=log(n)/log(2)+1;i++) 71 for (int j=1;j<=n;j++) 72 f[j][i]=f[f[j][i-1]][i-1]; 73 for (int i=1;i<=q;i++) 74 { 75 int u,v; 76 scanf("%d%d",&u,&v); 77 int lca=getlca(u,v); 78 printf("%lld\n",(k1[u]-k1[lca]+k2[v]-k2[lca]+mo)%mo); 79 } 80 return 0; 81 }