Loading

Tree Cutting 题解

vjudge 题面

考虑暴力 dp:设 \(dp[i][j]\) 表示在以 \(i\) 为根的子树中选的包含节点 \(i\) 的连通块的价值为 \(j\) 的方案数,转移时合并儿子的 dp 状态即可。复杂度 \(O(n\times m^2)\)

这个暴力慢在合并两个儿子的状态,因为合并两个 dp 数组是 \(O(m^2)\) 的,只要用到基本上就寄了。

重新考虑连通块这个限制,其实就是对于每个点 \(u\),如果 \(u\) 不在联通块中,那么 \(u\) 的子树中所有点都不能在连通块中。进一步转化到 dfs 序上,就可以得到一个更好的 dp:设 \(dp[u][i][j]\) 表示在 \(u\) 的子树中,考虑到 dfs 序为 \(i\) 的点,连通块价值为 \(j\) 的方案数。当然这个连通块是强制要求包含 \(u\) 的。考虑转移,分类讨论连通块包不包含 \(i\),若包含则直接转移到 \(dp[u][i+1][j\oplus v_i]\),否则跳过整个 \(i\) 的子树,转移到 \(dp[u][i+sz_i][j]\)。这个 dp 是 \(O(n^2m)\) 的。

进一步优化,只需点分治优化枚举根的过程即可。

#include <bits/stdc++.h>
using namespace std;

int read() {
	int s = 0, f = 1;
	char ch = getchar();
	while (ch < '0' || ch > '9')
		f = (ch == '-' ? -1 : 1), ch = getchar();
	while (ch >= '0' && ch <= '9')
		s = (s << 1) + (s << 3) + (ch ^ 48), ch = getchar();
	return s * f;
}

const int mod = 1000000007;

int n, m, k;
int a[1005];
int head[1005], to[2005], nxt[2005], tot = 0;
int ht, Maxp, sz[1005];
bool f[1005];
int ans[1024], dfn[1005], rnk[1005], timer = 0;
int dp[1005][1024] = {{0}};

void add(int u, int v) {
	to[++tot] = v, nxt[tot] = head[u], head[u] = tot;
}

void H(int p, int fa, int SZ) {
	int maxp = 0;
	sz[p] = 1;
	for (int i = head[p]; i; i = nxt[i])
		if (to[i] != fa && !f[to[i]])
			H(to[i], p, SZ), sz[p] += sz[to[i]], maxp = max(maxp, sz[to[i]]);
	maxp = max(maxp, SZ - sz[p]);
	if (maxp < Maxp)
		Maxp = maxp, ht = p;
}

void dfs(int p, int fa) {
	dfn[p] = ++timer, rnk[timer] = p, sz[p] = 1;
	for (int i = head[p]; i; i = nxt[i])
		if (to[i] != fa && !f[to[i]])
			dfs(to[i], p), sz[p] += sz[to[i]];
}

void Solve(int rt) {
	Maxp = 1e9, H(rt, 0, sz[rt]), rt = ht;
	timer = 0, dfs(rt, 0);
	dp[2][a[rt]] = 1;
	for (int i = 2; i <= timer; i++)
		for (int j = 0; j < m; j++) {
			dp[i + 1][j ^ a[rnk[i]]] = (dp[i + 1][j ^ a[rnk[i]]] + dp[i][j]) % mod;
			dp[i + sz[rnk[i]]][j] = (dp[i + sz[rnk[i]]][j] + dp[i][j]) % mod;
		}
	for (int i = 0; i < m; i++)
		ans[i] = (ans[i] + dp[timer + 1][i]) % mod;
	for (int i = 1; i <= timer + 1; i++)
		for (int j = 0; j < m; j++)
			dp[i][j] = 0;
	f[rt] = true;
	for (int i = head[rt]; i; i = nxt[i])
		if (!f[to[i]])
			Solve(to[i]);
}

void init() {
	memset(f, 0, sizeof f);
	memset(ans, 0, sizeof ans);
	memset(head, 0, sizeof head);
	tot = 0;
}

signed main() {
	int T = read();
	while (T--) {
		init();
		n = read(), m = read();
		for (int i = 1; i <= n; i++)
			a[i] = read();
		for (int i = 1; i < n; i++) {
			int u = read(), v = read();
			add(u, v), add(v, u);
		}
		Solve(1);
		for (int i = 0; i < m; i++) {
			printf("%lld", ans[i]);
			if (i < m - 1)
				putchar(' ');
		}
		putchar('\n');
	}
	return 0;
}
posted @ 2023-03-28 16:39  Galetx  阅读(24)  评论(0编辑  收藏  举报