「NOIP2024-树的遍历」题解
P11363 [NOIP2024] 树的遍历
sol
本篇题解来源于我此前某篇被删除的补题记录所以马蜂与语言风格可能与当前不同。
我们考虑,对于一棵新树,它可能从哪些边开始遍历,也就是哪些边可能是它的根节点。
然后你发现对于一棵新树,它可能的根节点,构成了从原树的一个叶子节点到另一个叶子节点的链。
你会发现,对于一个在这条链上的原树节点,其相邻的所有边在新树上的连边方式必然是:从一个可能根节点,以任意顺序连向不在链上的边,最后连向另一个可能根节点。
因此,如果链的两端不是原树叶子节点的话,其所连的其他原边中必还有一边可以作为新树根节点。
然后我们就可以对链构造了。
考虑对这样的一条链,其能构造出的不同新树数量。
首先我们考虑从其中一个根节点任意走能走出的种类(由于不能算重,故而只需统计一个根节点出发的情况即可):
\[\prod_i^n (dg_i-1)!
\]
也就是每一个点除了入边以外其余边任意排列连边。
然后考虑链固定时,每个链上节点最后走到的节点也是固定的,那么就需要除以:
\[\prod_{i\in L} (dg_i-1)
\]
\(L\) 表示链上的点集合。
那么对于一条链的答案就是:
\[\prod_i^n(dg_i-1)!\prod_{i\in L}(dg_i-1)^{-1}
\]
此外考虑可能出现的链,其中必然至少包含一个关键边。然后我们就可以使用树上 DP 解决这个问题。
code
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef __int128 i128;
typedef double db;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<int,ll> pil;
typedef pair<ll,int> pli;
template <typename Type>
using vec=vector<Type>;
template <typename Type>
using grheap=priority_queue<Type>;
template <typename Type>
using lrheap=priority_queue<Type,vector<Type>,greater<Type> >;
#define fir first
#define sec second
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define chmax(a,b) a=max(a,b)
#define chmin(a,b) a=min(a,b)
#define rep(i,x,y) for(int i=(x);i<=(y);i++)
#define per(i,x,y) for(int i=(x);i>=(y);i--)
#define repl(i,x,y) for(int i=(x);i<(y);i++)
#define file(f) freopen(#f".in","r",stdin);freopen(#f".out","w",stdout);
const int inf=0x3f3f3f3f;
const ll INF=0x3f3f3f3f3f3f3f3f;
const int mod=1e9+7/*998244353*/;
const int N=1e5+5;
int n,k;
int u[N],v[N],gj[N];
int dg[N];
vec<int> g[N];
ll inv[N];
ll ans;
ll f[N][2];
void dfs(int now,int fid){
ll val=inv[dg[now]-1];
f[now][0]=f[now][1]=0;
for(auto e:g[now]){
int nxt=(v[e]==now?u[e]:v[e]);
if(nxt==fid)continue;
dfs(nxt,now);
if(gj[e])f[nxt][1]+=f[nxt][0],f[nxt][1]%=mod,f[nxt][0]=0;
ans+=(f[now][1]*(f[nxt][0]+f[nxt][1])%mod+f[now][0]*f[nxt][1]%mod)%mod*val%mod,ans%=mod;
f[now][1]+=f[nxt][1],f[now][1]%=mod;
f[now][0]+=f[nxt][0],f[now][0]%=mod;
}
if(dg[now]==1){
f[now][0]=val;
}else{
f[now][0]*=val,f[now][0]%=mod;
f[now][1]*=val,f[now][1]%=mod;
}
}
void solve(){
cin>>n>>k;
rep(i,1,n)dg[i]=0,g[i].clear();
repl(i,1,n){
cin>>u[i]>>v[i];
gj[i]=0;
dg[u[i]]++;dg[v[i]]++;
g[u[i]].pub(i);
g[v[i]].pub(i);
}
rep(i,1,k){
int e;cin>>e;
gj[e]=1;
}
if(n==2){
cout<<"1\n";
return;
}
ans=0;
int rt=0;
rep(i,1,n)if(dg[i]>1){
rt=i;
break;
}
dfs(rt,rt);
rep(i,1,n)rep(j,1,dg[i]-1)ans*=j,ans%=mod;
cout<<ans<<"\n";
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
int c,t;
cin>>c>>t;
inv[0]=inv[1]=1;repl(i,2,N)inv[i]=(mod-mod/i)*inv[mod%i]%mod;
while(t--)solve();
return 0;
}

浙公网安备 33010602011771号