洛谷P6773 [NOI2020]命运(整体 dp)

洛谷P6773 [NOI2020]命运(整体 dp)

题目大意

形式化的:给定一棵树 \(T = (V, E)\) 和点对集合 \(\mathcal Q \subseteq V \times V\) ,满足对于所有 \((u, v) \in \mathcal Q\),都有 \(u \neq v\),并且 \(u\)\(v\) 在树 \(T\) 上的祖先。其中 \(V\)\(E\) 分别代表树 \(T\) 的结点集和边集。求有多少个不同的函数 \(f\) : \(E \to \{0, 1\}\)(将每条边 \(e \in E\)\(f(e)\) 值置为 \(0\)\(1\)),满足对于任何 \((u, v) \in \mathcal Q\),都存在 \(u\)\(v\) 路径上的一条边 \(e\) 使得 \(f(e) = 1\)。由于答案可能非常大,你只需要输出结果对 \(998,244,353\)(一个素数)取模的结果。

数据范围

全部数据满足\(n \leq 5 \times 10^5\)\(m \leq 5 \times 10^5\)。输入构成一棵树,并且对于 \(1 \leq i \leq m\)\(u_i\) 始终为 \(v_i\) 的祖先结点。

解题思路

首先题目大意就是每个边可以是 0 或 1,有 m 个条件,要求树上的一条直上直下的链中至少有一个 1,求方案数。

考虑 dp,\(dp[x][y]\) 表示 x 子树内的边状态已经确定,下端点在子树内且没有被满足的条件中,上端点最深的深度是 y,特殊地,如果子树内的条件都满足那么 y = 0。记录最深是因为深得满足浅的也就会跟着满足。

考虑如何转移,对于 x 的儿子点 y,有

\[\large dp'[x][i] = \sum_{j=0}^{dep_x} dp[x][i] \times dp[y][j]+ \sum_{j=0}^i dp[x][i] \times dp[y][j] \large + \sum_{j=0}^{i-1} dp[y][i] \times dp[x][j] \]

简单来说就是看这条边是 1 还是 0,前缀和优化一下就有 64 pt,其中如果建出来虚树就能再拿 8 pt。

现在是正解时间,重新审查我们的 dp

\[\large dp'[x][i] = dp[x][i] \times (sum[y][dep_x]+ sum[y][i]) + dp[y][i] \times sum[x][i-1] \]

考虑线段树合并,\(sum[y][i]\) 可以先从线段树上查一下,剩下的都是和下标相关的,我们先合并左子树然后合并右子树,走右子树的时候就可以把左边的 sum 加上了,整体乘一个数可以打个标记即可,细节如下

// s1 -> (sum[y][dep[x]]+sum[y][i]), s2 -> sum[x][i-1]
int merge(int x, int y, int l, int r, ll &s1, ll &s2) {
	if (!x && !y) return 0;
	if (!x || !y) {
		if (!x) {
			add(s1, sum(y));
			mul(y) = mul(y) * s2 % P;
			sum(y) = sum(y) * s2 % P;
			return y;
		}
		add(s2, sum(x));
		sum(x) = sum(x) * s1 % P, mul(x) = mul(x) * s1 % P;
		return x;
	}
	if (l == r) {
		ll tx = sum(x), ty = sum(y); add(s1, ty);
		sum(x) = (sum(x) * s1 + sum(y) * s2) % P;
		add(s2, tx); 
		return x;
	}
	spread(x), spread(y); //下放标记
	int mid = (l + r) >> 1;
	ls(x) = merge(ls(x), ls(y), l, mid, s1, s2); 
	rs(x) = merge(rs(x), rs(y), mid + 1, r, s1, s2);
	sum(x) = (sum(ls(x)) + sum(rs(x))) % P;
	return x;
}

/*
     />  フ
     |  _  _|
     /`ミ _x 彡
     /      |
    /   ヽ   ?
 / ̄|   | | |
 | ( ̄ヽ__ヽ_)_)
 \二つ
 */

#include <queue>
#include <vector>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define MP make_pair
#define ll long long
#define fi first
#define se second
using namespace std;

template <typename T>
void read(T &x) {
    x = 0; bool f = 0;
    char c = getchar();
    for (;!isdigit(c);c=getchar()) if (c=='-') f=1;
    for (;isdigit(c);c=getchar()) x=x*10+(c^48);
    if (f) x=-x;
}

template<typename F>
inline void write(F x, char ed = '\n') {
	static short st[30];short tp=0;
	if(x<0) putchar('-'),x=-x;
	do st[++tp]=x%10,x/=10; while(x);
	while(tp) putchar('0'|st[tp--]);
	putchar(ed);
}

template <typename T>
inline void Mx(T &x, T y) { x < y && (x = y); }

template <typename T>
inline void Mn(T &x, T y) { x > y && (x = y); }

const int P = 998244353;
const int N = 500050;
vector<int> v[N];
int h[N], ne[N<<1], to[N<<1], dep[N], tot, m, n;
inline void adde(int x, int y) {
	ne[++tot] = h[x], to[h[x] = tot] = y;
}

inline void add(ll &x, ll y) { x += y, (x >= P) && (x -= P); }

struct node {
	int ls, rs;
	ll sum, mul;
	#define mul(p) t[p].mul
	#define ls(p) t[p].ls
	#define rs(p) t[p].rs
	#define sum(p) t[p].sum
}t[N<<5];

void spread(int p) {
	if (ls(p)) {
		sum(ls(p)) = sum(ls(p)) * mul(p) % P;
		mul(ls(p)) = mul(ls(p)) * mul(p) % P;
	}
	if (rs(p)) {
		sum(rs(p)) = sum(rs(p)) * mul(p) % P;
		mul(rs(p)) = mul(rs(p)) * mul(p) % P;
	}
	mul(p) = 1;
}

ll query(int rt, int l, int r, int L) {
	if (!rt || r <= L) return sum(rt);
	int mid = (l + r) >> 1; ll sum = 0;
	spread(rt);
	if (mid < L) add(sum, query(rs(rt), mid + 1, r, L));
	add(sum, query(ls(rt), l, mid, L));
	return sum;
}

int cnt;
void change(int &p, int l, int r, int x) {
	p = ++cnt, sum(p) = mul(p) = 1;
	if (l == r) return;
	int mid = (l + r) >> 1;
	if (x <= mid) change(ls(p), l, mid, x);
	else change(rs(p), mid + 1, r, x);
}

int merge(int x, int y, int l, int r, ll &s1, ll &s2) {
	if (!x && !y) return 0;
	if (!x || !y) {
		if (!x) {
			add(s1, sum(y));
			mul(y) = mul(y) * s2 % P;
			sum(y) = sum(y) * s2 % P;
			return y;
		}
		add(s2, sum(x));
		sum(x) = sum(x) * s1 % P, mul(x) = mul(x) * s1 % P;
		return x;
	}
	if (l == r) {
		ll tx = sum(x), ty = sum(y); add(s1, ty);
		sum(x) = (sum(x) * s1 + sum(y) * s2) % P;
		add(s2, tx);
		return x;
	}
	spread(x), spread(y);
	int mid = (l + r) >> 1;
	ls(x) = merge(ls(x), ls(y), l, mid, s1, s2); 
	rs(x) = merge(rs(x), rs(y), mid + 1, r, s1, s2);
	sum(x) = (sum(ls(x)) + sum(rs(x))) % P;
	return x;
}

int T[N];
void dfs(int x, int fa) {
	dep[x] = dep[fa] + 1; int mx = 0;
	for (auto t: v[x]) Mx(mx, dep[t]);
	change(T[x], 0, n, mx);
	for (int i = h[x]; i; i = ne[i]) {
		int y = to[i]; if (y == fa) continue;
		dfs(y, x); 
		ll S = query(T[y], 0, n, dep[x]), SS = 0;
		T[x] = merge(T[x], T[y], 0, n, S, SS);
	}
}

int main() {
//	freopen ("destiny.in","r",stdin);
//	freopen ("destiny.out","w",stdout);
	read(n);
	for (int i = 1, x, y;i < n; i++) 
		read(x), read(y), adde(x, y), adde(y, x);
	read(m);
	for (int i = 1, x, y;i <= m; i++) 
		read(x), read(y), v[y].push_back(x);
	dfs(1, 0), write(query(T[1], 0, n, 0));
	return 0;
}
posted @ 2020-08-19 18:11  Hs-black  阅读(381)  评论(2编辑  收藏  举报