P6773 [NOI2020] 命运

标签:DS \(B\) | DP \(B\)

题意 给定一棵 $n 个点的树和 $m$ 条祖先到儿子的路径,求给边黑白染色的方案数使得每条路径上至少有一个黑点。

\(n,m \leq 5 \times 10^5\)

你考虑显然如果一个子树内有路径没有被消除,并且延伸到子树外,那么我们显然只关心深度最深的那个点。

那么我们设 \(f_{u,j}\) 表示子树 \(u\) 中没有被消除过的路径中祖先节点最深的深度为 \(j\) 时的方案数。

我们枚举一条边是否染成黑色,可以得到转移如下:

\[f_{u,i} = \sum_{j=0}^{dep_u} f_{u,i}f_{v,j}+\sum_{j=0}^i f_{u,i}f_{v,j} + \sum_{j=0}^{i-1} f_{u,j}f_{v,i} \]

显然我们可以将与 \(j\) 无关的项提出来,就变成了下式:

\[f_{u,i} = f_{u,i}(\sum_{j=0}^{dep_u} f_{v,j}+\sum_{j=0}^i f_{v,j}) + f_{v,i}\sum_{j=0}^{i-1} f_{u,j} \]

我们设 \(sum_{i,j} = \sum_{k = 0}^{j} f_{i,k}\) 那么就可以得到一个 \(n^2\) 的转移:

\[f_{u,i} = f_{u,i}(sum_{v,dep_u}+sum_{v,i}) + f_{v,i}sum_{u,i-1} \]

我们考虑因为只有 \(m\) 个询问,所以说 \(dp\) 时很多状态和转移都是没有什么必要的,但是我们也不能直接优化掉转移。

考虑对于这种类型的 \(dp\) 我们常见的就是用 set+启发式合并线段树合并长链剖分 等方式进行优化。

显然第一种是 \(O(n\log^2 n)\) 的,不可接受,所以我们使用 线段树合并 解决这个问题。

考虑 \(sum_{v,dep_u}\) 可以提前求出,而 \(sum_{v,i}\),\(sum_{u,i-1}\) 可以在线段树的过程中动态增加,至于和对应权值的乘法就直接打乘法标记即可,合并的时候会将两个部分的答案加起来。

code:

#include<bits/stdc++.h>
using namespace std;
const int NN = 5e5 + 8,MOD = 998244353;
typedef long long ll;
int n,m;

inline int read(){
	register char c = getchar();
	register int res = 0;
	while(!isdigit(c)) c = getchar();
	while(isdigit(c)) res = res * 10 + c - '0', c = getchar();
	return res;
}

struct Edge{
	int to,next;
}edge[NN << 1];
int head[NN],cnt;
void init(){
	memset(head,-1,sizeof(head));
	cnt = 1;
}
void add_edge(int u,int v){
	edge[++cnt] = {v,head[u]};
	head[u] = cnt;
}

struct Seg{
	int ls,rs;
	ll num,mul;
	#define ls(x) tree[x].ls
	#define rs(x) tree[x].rs
	#define num(x) tree[x].num
	#define mul(x) tree[x].mul
}tree[NN << 5];
int nodecnt;
void addlz(int x,ll num){
	if(!x) return;
	num(x) = num(x) * num % MOD;
	mul(x) = mul(x) * num % MOD;
}
void pushup(int x){
	num(x) = (num(ls(x)) + num(rs(x))) % MOD;
}
void pushdown(int x){
	addlz(ls(x),mul(x));
	addlz(rs(x),mul(x));
	mul(x) = 1;
}
void build(int &x,int l,int r,int pos){
	if(!x) x = ++nodecnt;
	num(x) = mul(x) = 1;
	if(l == r) return;
	int mid = (l + r) / 2;
	if(pos <= mid) build(ls(x),l,mid,pos);
	else build(rs(x),mid + 1,r,pos);
}
ll query(int x,int l,int r,int pos){
	if(!x || r <= pos) return num(x);
	int mid = (l + r) / 2;
	ll res = 0;
	pushdown(x);
	if(pos <= mid) return query(ls(x),l,mid,pos);
	else return (num(ls(x)) + query(rs(x),mid+1,r,pos)) % MOD;
}
// s1 -> (sum[y][dep[x]]+sum[y][i]), s2 -> sum[x][i-1]
int merge(int x,int y,int l,int r,ll &s1,ll &s2){
	if(!x && !y) return 0;
	if(!x || !y){
		if(y){
			s1 = (s1 + num(y)) % MOD;
			addlz(y,s2);
			return y;
		}
		s2 = (s2 + num(x)) % MOD;
		addlz(x,s1);
		return x;
	}
	if(l == r){
		ll tx = num(x), ty = num(y);
		s1 = (s1 + ty) % MOD; 
		num(x) = (num(x) * s1 + num(y) * s2) % MOD;
		s2 = (s2 + tx) % MOD;
		return x;
	}
	pushdown(x),pushdown(y);
	int mid = (l + r) / 2;
	ls(x) = merge(ls(x),ls(y),l,mid,s1,s2);
	rs(x) = merge(rs(x),rs(y),mid+1,r,s1,s2);
	pushup(x);
	return x;
}

vector<int> Q[NN];

int dep[NN];
int rt[NN];
void dfs(int u,int fa){
	dep[u] = dep[fa] + 1;
	int mx = 0;
	for(int i : Q[u]) mx = max(mx,dep[i]);
	build(rt[u],0,n,mx);
	
	for(int i = head[u]; i != -1; i = edge[i].next){
		int v = edge[i].to;
		if(v == fa) continue;
		dfs(v,u);
		ll S1 = query(rt[v],0,n,dep[u]),S2 = 0;
//		printf("%lld\n",S1);
		rt[u] = merge(rt[u],rt[v],0,n,S1,S2);
	}
}

int main(){
	n = read();
	init();
	for(int i = 1,u,v; i < n; ++i){
		u = read();v = read();
		add_edge(u,v);add_edge(v,u);
	}
	m = read();
	for(int i = 1,u,v; i <= m; ++i){
		u = read();v = read();
		Q[v].push_back(u);
	}
	dfs(1,0);
	printf("%lld",query(rt[1],0,n,0));
}
posted @ 2024-02-05 20:53  ricky_lin  阅读(22)  评论(0)    收藏  举报