P6773 [NOI2020] 命运
题意
给定一棵 $n 个点的树和 $m$ 条祖先到儿子的路径,求给边黑白染色的方案数使得每条路径上至少有一个黑点。\(n,m \leq 5 \times 10^5\)
你考虑显然如果一个子树内有路径没有被消除,并且延伸到子树外,那么我们显然只关心深度最深的那个点。
那么我们设 \(f_{u,j}\) 表示子树 \(u\) 中没有被消除过的路径中祖先节点最深的深度为 \(j\) 时的方案数。
我们枚举一条边是否染成黑色,可以得到转移如下:
\[f_{u,i} = \sum_{j=0}^{dep_u} f_{u,i}f_{v,j}+\sum_{j=0}^i f_{u,i}f_{v,j} + \sum_{j=0}^{i-1} f_{u,j}f_{v,i}
\]
显然我们可以将与 \(j\) 无关的项提出来,就变成了下式:
\[f_{u,i} = f_{u,i}(\sum_{j=0}^{dep_u} f_{v,j}+\sum_{j=0}^i f_{v,j}) + f_{v,i}\sum_{j=0}^{i-1} f_{u,j}
\]
我们设 \(sum_{i,j} = \sum_{k = 0}^{j} f_{i,k}\) 那么就可以得到一个 \(n^2\) 的转移:
\[f_{u,i} = f_{u,i}(sum_{v,dep_u}+sum_{v,i}) + f_{v,i}sum_{u,i-1}
\]
我们考虑因为只有 \(m\) 个询问,所以说 \(dp\) 时很多状态和转移都是没有什么必要的,但是我们也不能直接优化掉转移。
考虑对于这种类型的 \(dp\) 我们常见的就是用 set+启发式合并、线段树合并、长链剖分 等方式进行优化。
显然第一种是 \(O(n\log^2 n)\) 的,不可接受,所以我们使用 线段树合并 解决这个问题。
考虑 \(sum_{v,dep_u}\) 可以提前求出,而 \(sum_{v,i}\),\(sum_{u,i-1}\) 可以在线段树的过程中动态增加,至于和对应权值的乘法就直接打乘法标记即可,合并的时候会将两个部分的答案加起来。
code:
#include<bits/stdc++.h>
using namespace std;
const int NN = 5e5 + 8,MOD = 998244353;
typedef long long ll;
int n,m;
inline int read(){
register char c = getchar();
register int res = 0;
while(!isdigit(c)) c = getchar();
while(isdigit(c)) res = res * 10 + c - '0', c = getchar();
return res;
}
struct Edge{
int to,next;
}edge[NN << 1];
int head[NN],cnt;
void init(){
memset(head,-1,sizeof(head));
cnt = 1;
}
void add_edge(int u,int v){
edge[++cnt] = {v,head[u]};
head[u] = cnt;
}
struct Seg{
int ls,rs;
ll num,mul;
#define ls(x) tree[x].ls
#define rs(x) tree[x].rs
#define num(x) tree[x].num
#define mul(x) tree[x].mul
}tree[NN << 5];
int nodecnt;
void addlz(int x,ll num){
if(!x) return;
num(x) = num(x) * num % MOD;
mul(x) = mul(x) * num % MOD;
}
void pushup(int x){
num(x) = (num(ls(x)) + num(rs(x))) % MOD;
}
void pushdown(int x){
addlz(ls(x),mul(x));
addlz(rs(x),mul(x));
mul(x) = 1;
}
void build(int &x,int l,int r,int pos){
if(!x) x = ++nodecnt;
num(x) = mul(x) = 1;
if(l == r) return;
int mid = (l + r) / 2;
if(pos <= mid) build(ls(x),l,mid,pos);
else build(rs(x),mid + 1,r,pos);
}
ll query(int x,int l,int r,int pos){
if(!x || r <= pos) return num(x);
int mid = (l + r) / 2;
ll res = 0;
pushdown(x);
if(pos <= mid) return query(ls(x),l,mid,pos);
else return (num(ls(x)) + query(rs(x),mid+1,r,pos)) % MOD;
}
// s1 -> (sum[y][dep[x]]+sum[y][i]), s2 -> sum[x][i-1]
int merge(int x,int y,int l,int r,ll &s1,ll &s2){
if(!x && !y) return 0;
if(!x || !y){
if(y){
s1 = (s1 + num(y)) % MOD;
addlz(y,s2);
return y;
}
s2 = (s2 + num(x)) % MOD;
addlz(x,s1);
return x;
}
if(l == r){
ll tx = num(x), ty = num(y);
s1 = (s1 + ty) % MOD;
num(x) = (num(x) * s1 + num(y) * s2) % MOD;
s2 = (s2 + tx) % MOD;
return x;
}
pushdown(x),pushdown(y);
int mid = (l + r) / 2;
ls(x) = merge(ls(x),ls(y),l,mid,s1,s2);
rs(x) = merge(rs(x),rs(y),mid+1,r,s1,s2);
pushup(x);
return x;
}
vector<int> Q[NN];
int dep[NN];
int rt[NN];
void dfs(int u,int fa){
dep[u] = dep[fa] + 1;
int mx = 0;
for(int i : Q[u]) mx = max(mx,dep[i]);
build(rt[u],0,n,mx);
for(int i = head[u]; i != -1; i = edge[i].next){
int v = edge[i].to;
if(v == fa) continue;
dfs(v,u);
ll S1 = query(rt[v],0,n,dep[u]),S2 = 0;
// printf("%lld\n",S1);
rt[u] = merge(rt[u],rt[v],0,n,S1,S2);
}
}
int main(){
n = read();
init();
for(int i = 1,u,v; i < n; ++i){
u = read();v = read();
add_edge(u,v);add_edge(v,u);
}
m = read();
for(int i = 1,u,v; i <= m; ++i){
u = read();v = read();
Q[v].push_back(u);
}
dfs(1,0);
printf("%lld",query(rt[1],0,n,0));
}
本文来自博客园,作者:ricky_lin,转载请注明原文链接:https://www.cnblogs.com/rickylin/p/-/solution_P6773

浙公网安备 33010602011771号