[NOI2020]命运(线段树合并优化dp)

题目:洛谷P6773LOJ#3340

题目描述:

给定一棵\(n\)个点、\(n-1\)条边的树,\(1\)号节点为根节点,树上每条边都可以取\(0\)或者\(1\),有\(m\)对限制,每对限制\((u,v)\),表示\(u\)\(v\)的路径上至少有一条边的值为\(1\),保证\(u\)\(v\)的祖先,问有多少中合法的方案对于所有\(m\)个限制都满足
\(n,m \leq 5 \cdot 10^{5}\)

蒟蒻题解:

容易想到容斥,但是只能做出部分分,一般来讲,容斥只能做限制数量较少的题

换个思路,树形\(dp\)

容易发现,对于一些下端点相同的限制,只需要上端点深度最大的限制满足,那么这些下端点相同的限制都能满足

对于一棵子树,假设两个端点都在这棵子树内的限制已经满足了,但是其他的限制不好做,怎么办呢?

只考虑定下了这棵子树内的边,其他的边未定下,初始全取\(0\),那么对于这棵子树内边的取值会直接影响到的限制是一个端点在子树内,另一个端点在子树外的且这个限制还没有被满足的情况,这种限制可以转换为一个端点为这棵子树的根节点,一个端点在子树外,在子树外的端点一定是这棵子树跟姐点的祖先

要满足这些限制,由于这些限制的下端点转换成了同一个点,所以只需要这棵子树的根节点和那些限制的上节点中深度最大的点满足就可以了

我们可以设\(f[x][y]\)表示以\(x\)这根节点的这棵子树内边的取值已经定下来,不满足的限制中上方节点的最大深度为\(y\)的方案数,要满足限制就是\(x\)到它祖先中深度为\(y\)的节点之间至少存在一条边的取值为\(1\)

枚举\(dp\)转移,假设现在在点\(x\),它有一个儿子为\(y\),根据把他们之间的这条边的值定义为\(0\)\(1\)去讨论可以推出:

\[f[x][i]=\sum_{j=0}^{dep[x]}f[y][j]\cdot f[x][i]+\sum_{j=0}^{i}f[y][j]\cdot f[x][i]+\sum_{j=0}^{i-1}f[x][j]\cdot f[y][i] \]

再化简一下,得到:

\[f[x][i]=f[x][i]\cdot (\sum_{j=0}^{dep[x]}f[y][j] + \sum_{j=0}^{i}f[y][j])+f[y][i]\cdot \sum_{j=0}^{i-1}f[x][j] \]

\(\sum\)用一个数组\(sum[x][i]\)代替,记录前缀和,\(sum[x][i]=f[x][1]+f[x][2]+···+f[x][i]\)

那么原式可以表示为:

\[f[x][i]=f[x][i]\cdot (sum[y][dep[x]]+sum[y][i])+f[y][i]\cdot sum[x][i-1] \]

这样我们就得到了一个\(\Theta(n^{2})\)的暴力做法了,能拿到\(36\)分的好成绩

由于它的数据有的点\(m\)很小,有的点满足完全二叉树,结合多种方法应该能拿到较高的分数

接着考虑优化这个\(dp\),由于它限制的条数只有\(m\)条,所以我们这个方程会有很多没用的东西,既耗时间又耗内存,考虑动态开点,又要将子树的信息合并起来,那么想到线段树合并,时间复杂度是\(\Theta(n\ logn)\)

参考程序:

#include<bits/stdc++.h>
using namespace std;
#define Re register int
typedef long long ll;

const int N = 500005;
const int M = 10000005;
const int p = 998244353;
int n, m, cnt, ht, num, rt[N], fa[N], dep[N], g[N], hea[N], nxt[N << 1], to[N << 1], lc[M], rc[M], sum[M], mul[M];

inline int read()
{
	char c = getchar();
	int ans = 0;
	while (c < 48 || c > 57) c = getchar();
	while (c >= 48 && c <= 57) ans = (ans << 3) + (ans << 1) + (c ^ 48), c = getchar();
	return ans;
}

inline int max(int x, int y)
{
	return (x > y) ? x : y;
}

inline int inc(int x, int y)
{
	x += y;
	return (x < p) ? x : x - p;
}

inline void add(int x, int y)
{
	nxt[++cnt] = hea[x], to[cnt] = y, hea[x] = cnt;
}

inline void dfs(int x)
{
	dep[x] = dep[fa[x]] + 1, ht = max(ht, dep[x]);
	for (Re i = hea[x]; i; i = nxt[i])
	{
		int u = to[i];
		if (u == fa[x]) continue;
		fa[u] = x;
		dfs(u);
	}
}

inline void push_down(int id)
{
	if (lc[id]) sum[lc[id]] = 1ll * sum[lc[id]] * mul[id] % p, mul[lc[id]] = 1ll * mul[lc[id]] * mul[id] % p;
	if (rc[id]) sum[rc[id]] = 1ll * sum[rc[id]] * mul[id] % p, mul[rc[id]] = 1ll * mul[rc[id]] * mul[id] % p;
	mul[id] = 1;
}

inline int bui(int l, int r, int x)
{
	int y = ++num;
	sum[y] = mul[y] = 1;
	if (l == r) return y;
	int mid = l + r >> 1;
	if (x <= mid) lc[y] = bui(l, mid, x);
	else rc[y] = bui(mid + 1, r, x);
	return y;
}

inline int merge(int id1, int id2, int l, int r, int &x, int &y)
{
	if (!id1 && !id2) return 0;
	if (!id1)
	{
		y = inc(y, sum[id2]);
		sum[id2] = 1ll * sum[id2] * x % p, mul[id2] = 1ll * mul[id2] * x % p;
		return id2;
	}
	if (!id2)
	{
		x = inc(x, sum[id1]);
		sum[id1] = 1ll * sum[id1] * y % p, mul[id1] = 1ll * mul[id1] * y % p;
		return id1;
	}
	if (l == r)
	{
		int z = sum[id1];
		y = inc(y, sum[id2]);
		sum[id1] = inc(1ll * sum[id1] * y % p, 1ll * sum[id2] * x % p);
		x = inc(x, z);
		return id1;
	}
	if (mul[id1] ^ 1) push_down(id1);
	if (mul[id2] ^ 1) push_down(id2);
	int mid = l + r >> 1;
	lc[id1] = merge(lc[id1], lc[id2], l, mid, x, y), rc[id1] = merge(rc[id1], rc[id2], mid + 1, r, x, y);
	sum[id1] = inc(sum[lc[id1]], sum[rc[id1]]);
	return id1;
}

inline int que(int id, int l, int r, int x)
{
	if (!id) return 0;
	if (r <= x) return sum[id];
	if (mul[id] ^ 1) push_down(id);
	int mid = l + r >> 1;
	if (x <= mid) return que(lc[id], l, mid, x);
	return inc(sum[lc[id]], que(rc[id], mid + 1, r, x));
}

inline void dfs1(int x)
{
	rt[x] = bui(0, ht, g[x]);
	for (Re i = hea[x]; i; i = nxt[i])
	{
		int u = to[i];
		if (u == fa[x]) continue;
		dfs1(u);
		int s1 = 0, s2 = que(rt[u], 0, ht, dep[x]);
		rt[x] = merge(rt[x], rt[u], 0, ht, s1, s2);
	}
}

int main()
{
	n = read();
	for (Re i = 1; i < n; ++i)
	{
		int u = read(), v = read();
		add(u, v), add(v, u);
	}
	dfs(1);
	m = read();
	for (Re i = 0; i < m; ++i)
	{
		int u = dep[read()], v = read();
		g[v] = max(g[v], u);
	}
	dfs1(1);
	printf("%d", que(rt[1], 0, ht, 0));
	return 0;
}
posted @ 2021-03-09 23:33  clfzs  阅读(261)  评论(0)    收藏  举报