P11363 [NOIP2024] 树上遍历 题解
难是真的难,也值得好好整理一下。
首先,一个图的dfs生成树有一个性质:不含横插边。
所以在本题中,因为原图是一棵树,所以一个点 \(u\) 周围的边在生成树上一定在一条链上(首尾相连)。
那么,设 \(d_i\) 为 \(i\) 的度数,则对于一个点周围的边,一共有 \(d_i!\) 中排列方式。如果确定链的起点,则有 \((d_i-1)!\) 种情况。
(如下图的红色边)

所以对于 \(k=1\) 的情况下,答案就是 \(\Pi_{i=1}^n (d_i-1)!\)。因为每一个链的起点是定的。
我们称用来生成这棵dfs树的关键边为 生成树的根。
现在的问题在于,我们可能会重复计算一棵生成树多次。
所以现在我们需要考虑一棵生成树可能会被那些关键边给计算。
考虑下面这个图,这是以边 \((5,8)\) 构成的生成树。

其中灰色、粉红、橙色、绿色分别为点 \(5、3、4、1\) 相邻的边构成的链。其他的点因为度数为 \(1\) ,所以我们暂时不管他们。
因为一个点周围的边要构成一条链,所以 \((3,4)、(4,7)\) 就不能作为这棵生成树的根。因为从他们出发,点 \(3\) 的链就不是链了。
所以发现,上图的红色边就是可以作为这棵生成树根的边。
所以可以发现,能作为同一棵生成树的根的边,是原树中一个叶子节点到另一个叶子节点之间的路径上的边。并且与生成树建立起双射,就是一条这样的路径只对应一棵生成树,而一棵生成树也只对应一条这样的路径。
为什么这棵生成树一定存在这样一个路径可以作为根呢?
就如上图,我们是拿边 \((5,8)\) 生成的。此时向 \(8\) 节点走,发现 \(8\) 节点构成的链的起点一定是 \((5,8)\),而其终点 \((3,5)\) 则可以作为另一个根。于是就这样不断地从链的起点到链的终点,而链的终点又是下一个链的起点。这样不断遍历,最终会到达一个叶子节点。
因为需要保证,从可能的根节点出发遍历一棵树,原本钦定为链的部分必须还是链。
那么只要我从一个根向两边跳,每一次都从链顶跳到链尾,这样就可以找到一个路径。
所以说,我们只需要统计有多少个 原树中从一个叶子到另一个叶子之间的路径使得这条路径上至少有 \(k\) 条关键边中的一条。
那么如何统计方案数?
我们会发现,对于一条路径上可以作为根节点的边的端点,其链的起点与终点都是定的,单个方案数为 \((d_i-2)!\)。所以方案数为 \(\Pi_{x在链上}(d_x-2)!\ \Pi_{x不在链上}(d_x-1)!\)。为了方便,我们把它改成 \(\Pi(d_i-1)!\ \Pi_{x在链上}(d_x-1)^{-1}\)。
这样问题就变为:对于所有叶子到叶子的路径,满足其中至少一条边为关建边,路径上所有点的权值乘积之和。这里让一个点的点权为 \((d_x-1)^{-1}\)。
先把 \(n=2\) 判掉。然后对于 \(n>2\),找一个度数 \(>1\) 的位置作为根方便后续dp。
所以考虑dp,设 \(dp_{u,0/1}\) 表示以 \(u\) 为根的子树中,从 \(u\) 出发,不经过/经过关建边的乘积之和。
那么进行分类讨论,对于 \(v\in son_u\)。
-
若 \((u,v)\) 为关建边
则答案贡献增加 \((dp_{u,0}+dp_{u,1})(dp_{v,0}+dp_{v,1})\)。
然后让 \(dp_{u,1}\gets dp_{v,1}+dp_{v,0}\)。
-
若 \((u,v)\) 不是关建边
则答案贡献增加 \(dp_{v,1}(dp_{u,0}+dp_{i,1})+dp_{v,0}\times dp_{u,1}\)。
然后 \(dp_{u,1}\gets dp_{v,1}\),\(dp_{u,0}\gets dp_{v,0}\)。
然后输出答案即可。时间复杂度 \(O(n)\)。为了方便,处理逆元我用的 \(O(n\log n)\)。
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+5,mod=1e9+7;
#define int long long
int head[N],cnt,n,k,u[N],v[N],w[N];
int ru[N],fac[N],inv[N],mul;
int dp[N][2],ans;
struct edge
{
int v,nxt,w;
}a[N<<1];
void add(int u,int v,int w)
{
a[++cnt].v=v;
a[cnt].w=w;
a[cnt].nxt=head[u];
head[u]=cnt;
}
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
void dfs(int u,int fa)
{
for(int i=head[u];i!=0;i=a[i].nxt)
{
int v=a[i].v;
if(v==fa) continue;
dfs(v,u);
if(a[i].w)
{
ans=(ans+(dp[v][0]+dp[v][1])*(dp[u][0]+dp[u][1])%mod)%mod;
dp[u][1]=(dp[u][1]+inv[ru[u]-1]*(dp[v][0]+dp[v][1]))%mod;
}
else
{
ans=(ans+dp[v][1]*(dp[u][0]+dp[u][1])%mod+dp[v][0]*dp[u][1]%mod)%mod;
dp[u][1]=(dp[u][1]+inv[ru[u]-1]*dp[v][1])%mod;
dp[u][0]=(dp[u][0]+inv[ru[u]-1]*dp[v][0])%mod;
}
}
if(ru[u]==1) dp[u][0]=inv[ru[u]-1];
}
void solve()
{
n=read(),k=read();
for(int i=1;i<=n;i++) head[i]=0,dp[i][0]=dp[i][1]=ru[i]=0;
cnt=0;
ans=0;
mul=1;
for(int i=1;i<n;i++) u[i]=read(),v[i]=read(),w[i]=0;
for(int i=1;i<=k;i++) w[read()]=1;
for(int i=1;i<n;i++)
{
ru[u[i]]++;
ru[v[i]]++;
add(u[i],v[i],w[i]),add(v[i],u[i],w[i]);
}
if(n==2)
{
printf("1\n");
return;
}
int pos=0;
for(int i=1;i<=n;i++)
{
mul=mul*fac[ru[i]-1]%mod;
if(ru[i]>1) pos=i;
}
dfs(pos,0);
printf("%d\n",ans*mul%mod);
}
int qpow(int a,int b)
{
int ans=1;
while(b>0)
{
if(b&1) ans=(ans*a)%mod;
a=(a*a)%mod;
b>>=1;
}
return ans;
}
signed main()
{
fac[0]=1,inv[0]=1;
for(int i=1;i<=1e5;i++) fac[i]=fac[i-1]*i%mod,inv[i]=qpow(i,mod-2);
int ID=read(),T=read();
while(T--) solve();
return 0;
}

浙公网安备 33010602011771号