[状压dp] [树形dp] [计数问题] CF1799H Tree Cutting _ lone

posted on 2023-11-09 07:18:10 | under | source

很好的树上计数题,想通了就好做。

题意

给定一棵有根树。可以按顺序切掉 \(k\) 条边,并进入两个联通块中的一个。要求第 \(i\) 次进入的连通块大小 \(=a_i\),求合法方案数?

\(k\le 6\)

思路

看到题面和数据范围,应该想到树形 \(\rm dp\),并状压操作序列。

将操作分为:

  1. 进入子树。
  2. 进入非子树。

先关注操作 \(\rm 2\)。非子树部分的大小不好考虑,不妨转换为割掉的子树大小为 \(b_i=a_{i-1}-a_i\)。至此完成题目的初步转化。

然后思考其它限制。对操作序列(时间轴)进行讨论:

  • 对于操作 \(\rm1\):当 \(i\) 次操作为进入子树 \(v\) 时,则 \(i+1...n\) 操作都必定在 \(v\) 内进行。
  • 对于操作 \(\rm2\):没有太多限制,可以理解为删掉某个子树。只需满足大小限制即可。

不难发现操作 \(\rm2\) 更好做,先考虑全是它的情况。

定义 \(g_{u,i}\) 表示在 \(u\) 子树内,操作集合为 \(i\) 且只有操作 \(\rm 2\) 时的方案数。

然后想想转移,若按朴素的树形 \(\rm dp\) 做法,我们还需枚举 \(u\to v_i\) 的所有边。于是有个小技巧,状态多考虑 \(u\to fa\) 这条边。于是只需枚举子集转移,最后考虑这条边的贡献即可。

然后是操作 \(1\)。定义 \(f_{u,i,tim}\) 表示 \(u\) 子树操作集合为 \(i\),且最早的操作 \(\rm1\)\(tim\) 时刻的方案数。与 \(g\) 交替转移即可。

总结下 trick:

  1. 看数据范围考虑算法。
  2. 充分利用条件,转换限制使其相对宽松。
  3. 有多个数组需转移时,优先考虑简单的。
  4. 树形 \(\rm dp\) 时可以多向上考虑一条边。

具体转移还有其它细节就放到代码里了,复杂度 \(O(nk3^k)\)

代码

#include<bits/stdc++.h>
using namespace std;

#define int long long
#define pb push_back
#define ADD(a, b) a = ((a) + (b)) % mod
const int N = 5e3 + 5, M = 1 << 6, mod = 998244353;
int n, k, u, v, a[N], b[N];
int siz[N], ned[M + 5], mb[M + 5], g[N][M + 5], f[N][M + 5][15], ans;
vector<int>to[N];
//ned:状态i全是操作2时,删去的点数
//mb:i的最高位 

inline void init(){
	for(int i = 0; i < M;++i){
		for(int j = 5; j >= 0;--j) if((i >> j) & 1) {mb[i] = j; break;}
		for(int j = 0; j < 6;++j) if((i >> j) & 1) ned[i] += b[j + 1];
	}
	mb[0] = -114; //特殊情况,否则下面转移会错,样例都过不了 
}
inline void dfs(int u, int fa){
	g[u][0] = siz[u] = 1;
	for(auto v : to[u]){
		if(v ^ fa){
			dfs(v, u), siz[u] += siz[v];
			for(int i = M - 1; i > 0;--i) //注意先转移f再g 
				for(int tim = 0; tim < k;++tim)
					if((i >> tim) & 1)
						for(int j = i; j > 0;j = (j - 1) & i){ 
							if(mb[j] < tim) //选子树操作不在v内 
								ADD(f[u][i][tim], f[u][i ^ j][tim] * g[v][j] % mod);
							if(((j >> tim) & 1) && mb[i ^ j] < tim) //相反情况 
								ADD(f[u][i][tim], g[u][i ^ j] * f[v][j][tim] % mod);
						} 
			for(int i = M - 1; i > 0;--i)
				for(int j = i; j > 0;j = (j - 1) & i)
					ADD(g[u][i], g[u][i ^ j] * g[v][j] % mod); 
		}
	}
	if(fa){ //考虑删去u -> fa这条边 
		for(int i = M - 1; i > 0;--i)
			for(int tim = 0; tim < k;++tim)
				if((i >> tim) & 1){
					int j = i ^ (1 << tim);
					//即tim前面必须都是操作2 
					if(a[tim + 1] == (siz[u] - ned[j & ((1 << tim) - 1)])){ 
						for(int p = tim + 1; p < k;++p)
							ADD(f[u][i][tim], f[u][j][p]); //其余操作1必须在tim后面 
						ADD(f[u][i][tim], g[u][j]); //没有其余操作1 
					}
				} 
		for(int i = M - 1; i > 0;--i) //注意时间顺序:应该先修剪子树,再把整个子树剪掉 
			if(b[mb[i] + 1] == siz[u] - ned[i ^ (1 << mb[i])])
				ADD(g[u][i], g[u][i ^ (1 << mb[i])]);
	}
}
signed main(){
	cin >> n;
	for(int i = 1; i < n;++i) scanf("%lld%lld", &u, &v), to[u].pb(v), to[v].pb(u);
	cin >> k; a[0] = n;
	for(int i = 1; i <= k;++i) scanf("%lld", &a[i]), b[i] = a[i - 1] - a[i];
	
	init(), dfs(1, 0);
	
	ans = g[1][(1 << k) - 1];
	for(int i = 0; i < k;++i) ADD(ans, f[1][(1 << k) - 1][i]);
	cout << ans;
	return 0;
}
posted @ 2026-01-13 11:24  Zwi  阅读(0)  评论(0)    收藏  举报