Codeforces 1010F - Tree(分治 NTT+树剖)

Codeforces 题面传送门 & 洛谷题面传送门

神仙题。

首先我们考虑按照这题的套路,记 \(t_i\) 表示 \(i\) 上的果子数量减去其儿子果子数量之和,那么对于一个合法的放置果子的方案,必然有 \(t_i\ge 0,\sum\limits_{i=1}^nt_i=X\),而根据隔板法,对于一个大小为 \(s\) 的连通块,分配该连通块中的 \(t_i\) 的方案数为 \(\dbinom{X+s-1}{s-1}\),因此假设大小为 \(i\) 的连通块个数为 \(c_i\),那么答案即为 \(\sum\limits_{i=1}^nc_i\dbinom{X+i-1}{i-1}\),后面那坨组合数显然可以递推求出,因此我们的任务就是求出 \(c_i\)

考虑树形 DP,\(dp_{i,j}\) 表示 \(i\) 子树内大小为 \(j\) 的连通块有多少个,那么假设 \(u,v\)\(i\) 的两个儿子,那么显然有 \(dp_{i,j}=\sum\limits_{p+q=j-1}dp_{u,p}dp_{v,q}\)\(dp_{i,0}=1\)。直接做是平方的,不过注意到这东西一脸卷的样子,因此考虑设 \(F_i(x)\) 表示 \(dp_i\) 的 OGF,那么有 \(F_i(x)=F_u(x)F_v(x)·x+1\),但是直接卷依然会被卡成 \(n^2\log n\)。这时候就要用到类似于**动态 \(dp\) **的思路了,我们考虑对整棵树进行树链剖分并对于一条重链上的节点我们直接计算其链顶上的 \(dp\) 值,那么假设这条重链上的点从上至下分别是 \(p_1,p_2,p_3,\cdots,p_k\),并且对于每个 \(p_i(i\ne k)\),其不同于 \(p_{i+1}\) 的另一个儿子为 \(q_i\)(如果这样的儿子不存在那我们默认 \(q_i=0,F_{q_i}=1\)),那么显然 \(F_{p_k}=x+1,F_{p_i}=xF_{q_i}F_{p_{i+1}}+1\),递推以下可知 \(F_{p_1}=\sum\limits_{i=1}^k\prod\limits_{j=k-i+1}^kxF_{q_i}\)。因此到这里,该问题可以转化为,有 \(k\) 个生成函数 \(a_1,a_2,\cdots,a_k\),要求 \(a_1+a_1a_2+a_1a_2a_3+\cdots+a_1a_2\cdots a_k\) 的值。该问题可以分治 FFT(其实严格来说不能叫分治 FFT,因为没有 cdq 分治) 求解,具体来说当我们分治到区间 \([l,r]\) 时我们求出 \(A_{l,r}=\prod\limits_{i=l}^ra_i\),以及 \(B_{l,r}=\sum\limits_{i=l}^r\prod\limits_{j=i}^ra_j\),那么设 \(mid=\lfloor\dfrac{l+r}{2}\rfloor\),有 \(A_{l,r}=A_{l,mid}A_{mid+1,r},B_{l,r}=B_{l,mid}+A_{l,mid}B_{mid+1,r}\),分治求解即可。

至于复杂度,大概就对于每个点,其最多会对 \(\log n\) 个多项式的度产生 \(1\) 的贡献,再加上分治 FFT 的复杂度,总复杂度 \(n\log^3n\)

u1s1 vector 实现的 NTT 是真的好写,但常数是真的大……把多项式卷积里的 vector 换成数组后常数竟一下子将为原来的 \(\dfrac{1}{8}\)……

贴个稍微能看点的代码:

const int MAXN=1e5;
const int MAXP=1<<18;
const int pr=3;
const int ipr=332748118;
const int MOD=998244353;
int inv[MAXN+5];
int qpow(int x,int e){
	int ret=1;
	for(;e;e>>=1,x=1ll*x*x%MOD) if(e&1) ret=1ll*ret*x%MOD;
	return ret;
}
int rev[MAXP+5],A[MAXP+5],B[MAXP+5],C[MAXP+5];
void NTT(int *a,int len,int type){
	int lg=31-__builtin_clz(len);
	for(int i=0;i<len;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<lg-1);
	for(int i=0;i<len;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int i=2;i<=len;i<<=1){
		int W=qpow((type<0)?ipr:pr,(MOD-1)/i);
		for(int j=0;j<len;j+=i){
			for(int k=0,w=1;k<(i>>1);k++,w=1ll*w*W%MOD){
				int X=a[j+k],Y=1ll*a[(i>>1)+j+k]*w%MOD;
				a[j+k]=(X+Y)%MOD;a[(i>>1)+j+k]=(X-Y+MOD)%MOD;
			}
		}
	}
	if(!~type){
		int ivn=qpow(len,MOD-2);
		for(int i=0;i<len;i++) a[i]=1ll*a[i]*ivn%MOD;
	}
}
vector<int> conv(vector<int> x,vector<int> y){
	if(x.size()+y.size()<=16){
		vector<int> res(x.size()+y.size()-1);
		for(int i=0;i<x.size();i++) for(int j=0;j<y.size();j++)
			res[i+j]=(res[i+j]+1ll*x[i]*y[j])%MOD;
		return res;
	}
	int LEN=1;while(LEN<x.size()+y.size()) LEN<<=1;
	for(int i=0;i<LEN;i++) A[i]=B[i]=C[i]=0;
	for(int i=0;i<x.size();i++) A[i]=x[i];
	for(int i=0;i<y.size();i++) B[i]=y[i];
	NTT(A,LEN,1);NTT(B,LEN,1);
	for(int i=0;i<LEN;i++) C[i]=1ll*A[i]*B[i]%MOD;NTT(C,LEN,-1);
	vector<int> res;for(int i=0;i<x.size()+y.size()-1;i++) res.pb(C[i]);
	return res;
}
int n,hd[MAXN+5],to[MAXN*2+5],nxt[MAXN*2+5],ec=0;ll X;
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
int siz[MAXN+5],wson[MAXN+5],fa[MAXN+5];
void dfs0(int x,int f){
	siz[x]=1;fa[x]=f;
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==f) continue;
		dfs0(y,x);siz[x]+=siz[y];
		if(siz[y]>siz[wson[x]]) wson[x]=y;
	}
}
vector<vector<int> > p;
vector<int> dp[MAXN+5];
void addzero(vector<int> &v){v.insert(v.begin(),0);}
vector<int> plus_one(vector<int> v){(v[0]+=1)%=MOD;return v;}
vector<int> minus_one(vector<int> v){(v[0]+=MOD-1)%=MOD;return v;}
vector<int> plus_vec(vector<int> a,vector<int> b){
	if(a.size()<b.size()) swap(a,b);b.resize(a.size(),0);
	for(int i=0;i<a.size();i++) (a[i]+=b[i])%=MOD;
	return a;
}
pair<vector<int>,vector<int> > solve(int l,int r){
	if(l==r) return mp(plus_one(p[l]),p[l]);int mid=l+r>>1;
	pair<vector<int>,vector<int> > L=solve(l,mid);
	pair<vector<int>,vector<int> > R=solve(mid+1,r);
//	printf("%d %d %d %d!!!\n",L.fi.size(),L.se.size(),R.fi.size(),R.se.size());
	return mp(plus_vec(conv(minus_one(L.fi),R.se),R.fi),conv(L.se,R.se));
}
void dfs(int x){
	if(!wson[x]) return dp[x]=vector<int>(2,1),void();
	for(int y=x;y;y=wson[y]){
		for(int e=hd[y];e;e=nxt[e]){
			int z=to[e];if(z==fa[y]||z==wson[y]) continue;
			dfs(z);
		}
	} p.clear();
	for(int y=x;y;y=wson[y]){
		bool flg=0;
		for(int e=hd[y];e;e=nxt[e]){
			int z=to[e];if(z==fa[y]||z==wson[y]) continue;
			addzero(dp[z]);p.pb(dp[z]);flg=1;
		} if(!flg) p.pb(minus_one(vector<int>(2,1)));
	} assert(!p.empty());reverse(p.begin(),p.end());
	dp[x]=solve(0,p.size()-1).fi;//printf("%d: ",x);
//	for(int i=0;i<dp[x].size();i++) printf("%d ",dp[x][i]);
//	printf("!\n");
}
int main(){
	scanf("%d%lld",&n,&X);X%=MOD;
	for(int i=1,u,v;i<n;i++){scanf("%d%d",&u,&v);adde(u,v),adde(v,u);}
	dfs0(1,0);dfs(1);int ans=0;
	for(int i=(inv[0]=inv[1]=1)+1;i<=n;i++) inv[i]=1ll*inv[MOD%i]*(MOD-MOD/i)%MOD;
//	for(int i=0;i<dp[1].size();i++) printf("%d%c",dp[1][i]," \n"[i+1==dp[1].size()]);
	for(int i=1,t=1;i<dp[1].size();i++){
		ans=(ans+1ll*dp[1][i]*t)%MOD;
		t=1ll*t*(X+i)%MOD*inv[i]%MOD;
	} printf("%d\n",ans);
	return 0;
}
posted @ 2021-08-06 22:27  tzc_wk  阅读(57)  评论(0编辑  收藏  举报