[PKUWC2018]Minimax
题意:
给 \(n(\leq3\times 10^5)\) 个结点的二叉树,叶子结点有互不相同的权值 \(v\), 非叶子节点有概率 \(p\) 取两个子节点最大值,否则取最小值。
求 \(\sum_{i = 1}^{m} i \cdot v_i \cdot d_i^2\), 即 \(m\) 种可能的结果中,第 \(i\) 小的权值 \(v_i\) * \(i\) * 根节点是 \(v_i\) 的概率的平方。
树形dp
由于权值互不相同,所有的可能个数就是叶子节点的个数,只要求出根结点是出现某个权值的概率, 问题就可以解决了。
因为需要比较大小,得确定当前节点的权值。
方程
可以设出方程 \(f_{u, x}\), 表示当前节点 \(u\) 的权值是 第 \(x\) 大的权值 \(val\) 的概率。
初始状态
对于每个叶子节点 \(u\),这个点是 \(v_{u}\) 的概率肯定是 \(1\)。
因此: \(f_{u, x} = 1\)。
转移
分类讨论当前非叶子节点 \(u\) 的权值 \(val\) 从哪里来,令 \(val\) 是第 \(x\) 大的权值。
-
如果只有一个儿子,那就一定是从这个儿子来:
-
考虑从左节点来,有最大值和最小值两种途径。
- 最大值:\[f_{u, x} = f_{l, x} \times \sum_{i = 1}^{x - 1} f_{r, i} \times p \]
- 最小值:\[f_{u, x} = f_{l, x} \times \sum_{i = x + 1}^{m} f_{r, i} \times (1 - p) \]
- 最大值:
对于右节点同理。
就有:
化简:
分析
时间复杂度 \(O(n \times m)\), 空间复杂度 \(O(n \times m)\)。
对于空间,只有根节点才有用,没必要开那么大,可以用到类似滚动数组的玩意,但要在树上?
对于时间,用前后缀和优化的转移是 \(O(m)\)的,还是不够好?
线段树合并
对于上述问题
线段树合并的空间是 \(O(m \log m)\) 的,因为加入新点的空间花费是 \(O(\log m)\), 最多插入 \(O(m)\) 次。
转移的形式像对应点相加,在乘一个值 \(P\), 简单表示为 \(f_{i} = f1_{i} \times P_1 + f2_i \times P_2\)
而这个 \(P1, P2\) 可以在合并的时候顺路求出, 只用维护区间和就能做到。
如果遇到某个儿子的结点为空,其中的 \(f_{l, x}\) 或 \(f_{r, x}\), 那么打一个乘法的懒标记即可。
线段树合并时间复杂度均摊是 \(O(n \log m)\) 的。
代码
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
const int MAXN = 300010;
const int INF = 0x7fffffff;
const int mod = 998244353;
template <typename T>
void Read(T &x) {
x = 0; T f = 1; char a = getchar();
for(; a < '0' || '9' < a; a = getchar()) if (a == '-') f = -f;
for(; '0' <= a && a <= '9'; a = getchar()) x = (x * 10) + (a ^ 48);
x *= f;
}
int add(int a, int b) {
int c = a + b;
if (c >= mod) c -= mod;
if (c < 0) c += mod;
return c;
}
int mul(int a, int b) {
return 1ll * a * b % mod;
}
int qpow(int a, int b) {
int sum(1);
while(b) {
if (b & 1) sum = mul(sum, a);
a = mul(a, a);
b >>= 1;
}
return sum;
}
int n;
vector<int> e[MAXN];
int val[MAXN];
int cnt;
int L[MAXN << 5], R[MAXN << 5], Mul[MAXN << 5], sum[MAXN << 5];
int newnode() {
int rt = ++ cnt;
L[rt] = R[rt] = sum[rt] = 0;
Mul[rt] = 1;
return rt;
}
void pushup(int rt) {
sum[rt] = add(sum[L[rt]], sum[R[rt]]);
}
void _mul(int rt, int val) {
if (!rt) return ;
sum[rt] = mul(sum[rt], val);
Mul[rt] = mul(Mul[rt], val);
}
void pushdown(int rt) {
if (!rt) return ;
if (L[rt]) _mul(L[rt], Mul[rt]);
if (R[rt]) _mul(R[rt], Mul[rt]);
Mul[rt] = 1;
}
int update(int _L, int _C, int l, int r, int rt) {
if (!rt) rt = newnode();
if (l == r) {
sum[rt] = _C;
return rt;
}
pushdown(rt);
int m = (l + r) >> 1;
if (_L <= m) L[rt] = update(_L, _C, l, m, L[rt]);
else R[rt]= update(_L, _C, m + 1, r, R[rt]);
pushup(rt);
return rt;
}
int merge(int x, int y, int sum1, int sum2, int p, int l, int r) {
if (!x && !y) return 0;
if (!x) {
_mul(y, sum2);
return y;
}
if (!y) {
_mul(x, sum1);
return x;
}
pushdown(x), pushdown(y);
int m = (l + r) >> 1;
int ls1 = sum[L[x]], ls2 = sum[L[y]], rs1 = sum[R[x]], rs2 = sum[R[y]];
L[x] = merge(L[x], L[y], add(sum1, mul(add(1, -p), rs2)), add(sum2, mul(add(1, -p), rs1)), p, l, m);
R[x] = merge(R[x], R[y], add(sum1, mul(p, ls2)), add(sum2, mul(p, ls1)), p, m + 1, r);
pushup(x);
return x;
}
int len;
int b[MAXN];
int root[MAXN];
void dfs(int u) {
if (!e[u].size()) {
root[u] = update(lower_bound(b + 1, b + len + 1, val[u]) - b, 1, 1, len, root[u]);
} else if(e[u].size() == 1) {
int son = e[u][0];
dfs(son);
root[u] = root[son];
} else {
int l = e[u][0], r = e[u][1];
dfs(l), dfs(r);
root[u] = merge(root[l], root[r], 0, 0, val[u], 1, len);
}
}
int query(int l, int r, int rt) {
if (!rt) return 0;
if (l == r) {
return mul(mul(l, b[l]), qpow(sum[rt], 2));
}
int m = (l + r) >> 1;
pushdown(rt);
return add(query(l, m, L[rt]), query(m + 1, r, R[rt]));
}
int main() {
Read(n);
for (int i = 1; i <= n; i ++) {
int fa;
Read(fa);
e[fa].emplace_back(i);
}
for (int i = 1; i <= n; i ++) {
Read(val[i]);
if (e[i].size())
val[i] = mul(val[i], qpow(10000, mod - 2));
else
b[++ len] = val[i];
}
sort(b + 1, b + len + 1);
len = unique(b + 1, b + len + 1) - b - 1;
dfs(1);
cout << query(1, len, root[1]);
return 0;
}

浙公网安备 33010602011771号