P5298 [PKUWC2018] Minimax
线段树好题,虽说是看了题解的,但是细节还是不少。
因为叶子结点的权值互不相同,所以可以离散化,设所有叶子节点的权值不同种类的个数为 \(m\)。
设计状态 \(f_{i,j}\) 为点 \(i\) 取 \(j\) 的概率。
若点 \(i\) 为叶子节点,则 \(f_{i,w_i}\gets1\)。
若点 \(i\) 只有一个儿子,则 \(f_{i,j}\gets f_{son,j}\)。
若点 \(i\) 有两个儿子 \(ls,rs\),左儿子贡献为
\[f_{ls,j}\times(p_i\sum_{k=1}^{j-1}f_{rs,k}+(1-p_i)\sum_{k=j+1}^{m}f_{rs,k})
\]
上式没什么好解释的,括号内前半段为点 \(i\) 取最大值,后半段为点 \(i\) 取最小值。
同理我们就得到
\[f_{i,j}\gets f_{ls,j}\times(p_i\sum_{k=1}^{j-1}f_{rs,k}+(1-p_i)\sum_{k=j+1}^{m}f_{rs,k})+f_{rs,j}\times(p_i\sum_{k=1}^{j-1}f_{ls,k}+(1-p_i)\sum_{k=j+1}^{m}f_{ls,k})
\]
注意到式子中有四个前后缀和的形式,建立权值线段树,值域为可能取到的值,维护的为概率,就能用线段树合并的方式进行转移。
具体就是假设我们当前要合并 \(x,y\) 的子树,以 \(x\) 为例,我们需要维护的是当前区间的前面的和(\([1,l)\))及当前区间的后面的和(\((r,m]\))。然后若 \(x,y\) 中有一个点为空,说明概率为 \(0\),上式便缺少了一项,变成了乘法的形式,那么就可以打上标记,进行维护了。
看看代码
int merge_(int x, int y, int prex, int sufx, int prey, int sufy) {
if (!x && !y) return 0;
if (!y) return pushmul(x, (prey*wx+sufy*(mod+1-wx))%mod), x;
if (!x) return pushmul(y, (prex*wx+sufx*(mod+1-wx))%mod), y;
pushdown(x), pushdown(y);
int a = (sufx+sum[son[x][1]])%mod, b = (sufy+sum[son[y][1]])%mod;
int c = (prex+sum[son[x][0]])%mod, d = (prey+sum[son[y][0]])%mod;
son[x][0] = merge_(son[x][0], son[y][0], prex, a, prey, b);
son[x][1] = merge_(son[x][1], son[y][1], c, sufx, d, sufy);
return pushup(x), del(y), x;
}
这里名字取的是 merge_
,原本取得 merge
会报错,调了半天。
细节还有就是记得 pushdown
。
代码(调了好久qwq)
#include <bits/stdc++.h>
#define int long long
#define ls son[p][0]
#define rs son[p][1]
using namespace std;
const int N = 3e5+5, mod = 998244353, inv10000 = 796898467;
int n, wx, tot, cnt, ans, delcnt, w[N], tw[N], lc[N], rc[N], rt[N], gb[N*20], sum[N*20], mul[N*20], son[N*20][2];
int newnode() {
int rt = delcnt ? gb[delcnt--] : ++cnt;
sum[rt] = mul[rt] = 1;
return rt;
}
void del(int p) { gb[++delcnt] = p; ls = rs = sum[p] = mul[p] = 0; }
void pushup(int p) { sum[p] = (sum[ls]+sum[rs])%mod; }
void pushmul(int p, int x) { sum[p] = sum[p]*x%mod, mul[p] = mul[p]*x%mod; }
void pushdown(int p) {
if (mul[p] == 1) return;
if (ls) pushmul(ls, mul[p]);
if (rs) pushmul(rs, mul[p]);
mul[p] = 1;
}
void modify(int &p, int l, int r, int x) {
if (!p) p = newnode();
if (l == r) return;
int mid = (l+r)>>1;
if (x <= mid) modify(ls, l, mid, x);
else modify(rs, mid+1, r, x);
}
int merge_(int x, int y, int prex, int sufx, int prey, int sufy) {
if (!x && !y) return 0;
if (!y) return pushmul(x, (prey*wx+sufy*(mod+1-wx))%mod), x;
if (!x) return pushmul(y, (prex*wx+sufx*(mod+1-wx))%mod), y;
pushdown(x), pushdown(y);
int a = (sufx+sum[son[x][1]])%mod, b = (sufy+sum[son[y][1]])%mod;
int c = (prex+sum[son[x][0]])%mod, d = (prey+sum[son[y][0]])%mod;
son[x][0] = merge_(son[x][0], son[y][0], prex, a, prey, b);
son[x][1] = merge_(son[x][1], son[y][1], c, sufx, d, sufy);
return pushup(x), del(y), x;
}
void dfs(int x) {
if (!lc[x]) modify(rt[x], 1, tot, w[x]);
else if (!rc[x]) dfs(lc[x]), rt[x] = rt[lc[x]];
else dfs(lc[x]), dfs(rc[x]), wx = w[x], rt[x] = merge_(rt[lc[x]], rt[rc[x]], 0, 0, 0, 0);
}
void solve(int p, int l, int r) {
if (l == r) {
ans = (ans+l*tw[l]%mod*sum[p]%mod*sum[p])%mod;
return;
}
pushdown(p);
int mid = l+r>>1;
if (ls) solve(ls, l, mid);
if (rs) solve(rs, mid+1, r);
}
signed main() {
ios::sync_with_stdio(0);
cin >> n;
for (int i = 1, fa; i <= n; ++i) {
cin >> fa;
if (i > 1) lc[fa] ? rc[fa] = i : lc[fa] = i;
}
for (int i = 1; i <= n; ++i) {
cin >> w[i];
if (lc[i]) w[i] = w[i]*inv10000%mod;
else tw[++tot] = w[i];
}
sort(tw+1, tw+tot+1);
for (int i = 1; i <= n; ++i) if (!lc[i]) w[i] = lower_bound(tw+1, tw+tot+1, w[i])-tw;
dfs(1);
solve(rt[1], 1, tot);
cout << ans;
return 0;
}