[LOJ2983] [WC2019] 数树
题目链接
LOJ:https://loj.ac/problem/2983
BZOJ:https://lydsy.com/JudgeOnline/problem.php?id=5475
洛谷:https://www.luogu.org/problemnew/show/P5206
Soltion
超级毒瘤数数题...窝看了一晚上才看懂...
%%%rqy
subtask 0
很容易可以看出每个公共的边连成的联通块被绑一起了,所以答案就是\(y^{p}\),其中\(p\)为联通块个数。
也就是说答案为\(y^{n-m}\),其中\(m\)为公共边条数。
Subtask 1
设\(f(s)=y^{n-|s|}\),其中\(s\)为边集。
那么我们枚举树的形态可以得到答案为:
其中\(E_2\)枚举的是边集,\(E_1\)是题目给出的边集。这个\(E_1\cap E_2\)不是很好办,我们可以利用下面这个容斥的式子化简:
下面会给出证明,当然,读者自证不难。带进去可得:
其中倒数第二个等式到最后一个等式用到了二项式定理,\(g(s)\)表示包含边集\(s\)的树的个数,可以证明:
证明下面会给出,其中,\(s\)被分成了\(k\)个联通块,大小分别为\(a_1,a_2\cdots a_k\)。
这套交换求和符号还是看得懂的吧...不然怎么敢来刚这个题
那么带进去可得:
至此,这玩意还是指数级的,但是我们可以发现,\(\frac{ny}{1-y}\)是固定的,设其为\(k\),也就是说,每个联通块有\(ka_i\)的贡献。
那么,其实我们就可以\(dp\)了,设\(f_{i,j}\)表示第\(i\)个点,只考虑子树\(i\)所在的联通块\(size\)为\(j\)的贡献,注意不考虑\(i\)这个联通块的贡献(这里贡献指的是若干个\(ka_i\)的乘积)。
然后直接暴力背包转移,时间复杂度降为了\(O(n^2)\)。
我们固定\(f_{i,j}\)的\(i\),设一个幂级数:
再设答案为\(g_i(x)\):
其中\(F_i'\)表示求导。
那么背包转移就可以写成:
那么可得:
这里是照抄的乘积求导法则:
换一下求和符号就是:
然后令\(x=1\)得到:
令\(t_i=F_i(1)=\prod_{j\in son_i}(g_j+F_j(x))\),那么把上面的式子抄下来:
然后\(O(n)\)递推就好了。
Subtask 2
我们照抄上面的式子:
这里\(s\)枚举的是公共部分,所以\(g(s)\)要平方,因为两棵树都要满足。
展开:
过程和上面一样。
我们换种方式枚举,枚举多少个联通块以及大小分别是多少,显然联通块无顺序所以要除以\(k!\),然后点有标号所以乘上组合数\(\frac{n!}{\prod a_i!}\),\(a\)个点的树个数为\(a^{a-2}\),写出来就是:
后面的卷积形式写成多项式就是:
注意到这是个多项式\(\exp\)形式,直接算就好了,复杂度\(O(n\log n)\)。
其实多项式\(\exp\)的组合意义就是带标号的多重背包,所以上面符合多项式\(\exp\)的式子也就不出意外了。
上面lemma的proof
lemma 1:
其中\(f\)是一个和\(|s|\)有关的函数。
那么交换求和符号:
可以发现等式左边右边都是\(f(s)\),证毕。
lemma 2:
我们要把\(k\)个联通块组成的森林连成一棵树的方案数,其中联通块大小分别为\(a_1,a_2\cdots a_k\)。
考虑每个联通块视为一个点,那么它的\(prufer\)序列共有\(k-2\)个点,每个点在\([1,k]\)。
我们枚举每个点是什么,统计方案数:
其中\(b_i\)为第\(i\)位的数,\(c_i\)表示\(i\)出现了多少次。
交换求和符号:
注意当\(y=1\)时上面很多式子都没有意义,需要特判。
#include<bits/stdc++.h>
using namespace std;
void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}
#define lf double
#define ll long long
#define mp make_pair
#define fr first
#define sc second
const int maxn = 6e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 998244353;
int add(int x,int y) {return x+y>mod?x+y-mod:x+y;}
int del(int x,int y) {return x-y<0?x-y+mod:x-y;}
int mul(int x,int y) {return 1ll*x*y-1ll*x*y/mod*mod;}
int n,y;
int qpow(int a,int x) {
int res=1;
for(;x;x>>=1,a=mul(a,a)) if(x&1) res=mul(res,a);
return res;
}
namespace fuckpps {
void solve(int op) {
if(op==0) write(1);
else if(op==1) write(qpow(n,n-2));
else write(qpow(n,2*n-4));
}
}
namespace subtask0 {
void solve() {
map<pair<int,int >,int > s;int ans=0;
for(int i=1,x,yy;i<n;i++) read(x),read(yy),s[mp(min(x,yy),max(x,yy))]=1;
for(int i=1,x,yy;i<n;i++) read(x),read(yy),ans+=s[mp(min(x,yy),max(x,yy))];
write(qpow(y,n-ans));
}
}
namespace subtask1 {
int head[maxn],tot,g[maxn],t[maxn],k;
struct edge{int to,nxt;}e[maxn<<1];
void ad(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
void ins(int u,int v) {ad(u,v),ad(v,u);}
void dfs(int x,int fa) {
g[x]=k,t[x]=1;
for(int i=head[x],v;i;i=e[i].nxt)
if((v=e[i].to)!=fa) {
dfs(v,x);
g[x]=add(g[x],mul(g[v],qpow(add(g[v],t[v]),mod-2)));
t[x]=mul(t[x],add(g[v],t[v]));
}
g[x]=mul(g[x],t[x]);
}
void solve() {
k=mul(mul(n,y),qpow(1-y+mod,mod-2));
for(int i=1,u,v;i<n;i++) read(u),read(v),ins(u,v);
dfs(1,0);write(mul(g[1],mul(qpow(1-y+mod,n),qpow(mul(n,n),mod-2))));
}
}
namespace subtask2 {
int w[maxn],rw[maxn],N,bit,pos[maxn],f[maxn],fac[maxn],ifac[maxn],inv[maxn],mxn,g[maxn],K;
int tmp[7][maxn];
void ntt_init() {
w[0]=rw[0]=1,w[1]=qpow(3,(mod-1)/N);mxn=N;
for(int i=2;i<=N;i++) w[i]=mul(w[i-1],w[1]);
rw[1]=qpow(w[1],mod-2);
for(int i=2;i<=N;i++) rw[i]=mul(rw[i-1],rw[1]);
}
void ntt(int *r,int op) {
for(int i=1;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
for(int i=1,d=mxn>>1;i<N;i<<=1,d>>=1)
for(int j=0;j<N;j+=i<<1)
for(int k=0;k<i;k++) {
int x=r[j+k],y=mul((op==1?w:rw)[k*d],r[i+j+k]);
r[j+k]=add(x,y),r[i+j+k]=del(x,y);
}
if(op==-1) {int d=qpow(N,mod-2);for(int i=0;i<N;i++) r[i]=mul(r[i],d);}
}
void ntt_get(int len) {
for(N=1,bit=0;N<=len;N<<=1,bit++) ;
for(int i=0;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));
}
void poly_inv(int *r,int *t,int len) {
if(len==1) return t[0]=qpow(r[0],mod-2),void();
poly_inv(r,t,len>>1);
for(int i=0;i<len>>1;i++) tmp[0][i]=t[i],tmp[1][i]=r[i];
for(int i=len>>1;i<len;i++) tmp[0][i]=0,tmp[1][i]=r[i];
ntt_get(len),ntt(tmp[0],1),ntt(tmp[1],1);
for(int i=0;i<N;i++) t[i]=del(mul(2,tmp[0][i]),mul(mul(tmp[1][i],tmp[0][i]),tmp[0][i]));
ntt(t,-1);for(int i=len;i<N;i++) t[i]=0;
for(int i=0;i<len<<1;i++) tmp[0][i]=tmp[1][i]=0;
}
void poly_der(int *r,int *t,int len) {
ntt_get(len);
for(int i=1;i<len;i++) t[i-1]=mul(i,r[i]);
for(int i=len-1;i<N;i++) t[i]=0;
}
void poly_int(int *r,int *t,int len) {
ntt_get(len);
for(int i=0;i<len;i++) t[i+1]=mul(inv[i+1],r[i]);t[0]=0;
for(int i=len+1;i<N;i++) t[i]=0;
}
void poly_ln(int *r,int *t,int len) {
poly_der(r,tmp[2],len);
poly_inv(r,tmp[3],len);
ntt_get(len),ntt(tmp[2],1),ntt(tmp[3],1);
for(int i=0;i<N;i++) tmp[3][i]=mul(tmp[2][i],tmp[3][i]);
ntt(tmp[3],-1);
poly_int(tmp[3],t,len);
for(int i=0;i<len<<1;i++) tmp[3][i]=0;
}
void poly_exp(int *r,int *t,int len) {
if(len==1) return t[0]=1,void();
poly_exp(r,t,len>>1);
for(int i=0;i<len>>1;i++) tmp[4][i]=r[i],tmp[5][i]=t[i];
for(int i=len>>1;i<len;i++) tmp[4][i]=r[i],tmp[5][i]=0;
poly_ln(tmp[5],tmp[6],len);
for(int i=0;i<len;i++) tmp[4][i]=del(tmp[4][i],tmp[6][i]);
tmp[4][0]=add(tmp[4][0],1);
ntt_get(len),ntt(tmp[4],1),ntt(tmp[5],1);
for(int i=0;i<N;i++) t[i]=mul(tmp[4][i],tmp[5][i]);
ntt(t,-1);for(int i=len;i<N;i++) t[i]=0;
}
void solve() {
for(N=1,bit=0;N<=n<<2;N<<=1,bit++);ntt_init();
inv[0]=fac[0]=ifac[0]=inv[1]=1;K=mul(mul(n,mul(n,y)),qpow(1-y+mod,mod-2));
for(int i=2;i<N;i++) inv[i]=mul(mod-mod/i,inv[mod%i]);
for(int i=1;i<N;i++) fac[i]=mul(fac[i-1],i);
for(int i=1;i<N;i++) ifac[i]=mul(ifac[i-1],inv[i]);
for(int i=1;i<=n;i++) f[i]=mul(K,mul(qpow(i,i),ifac[i]));
ntt_get(n);poly_exp(f,g,N);
write(mul(g[n],mul(mul(qpow(1-y+mod,n),fac[n]),qpow(qpow(n,4),mod-2))));
}
}
int main() {
read(n),read(y);int op;read(op);
if(y==1) fuckpps :: solve(op);
else if(op==0) subtask0 :: solve();
else if(op==1) subtask1 :: solve();
else subtask2 :: solve();
return 0;
}