P5298 [PKUWC2018]Minimax
P5298 [PKUWC2018]Minimax
首先考虑最简单的 \(\text{dp}\) 式子。
设 \(dp_{x,j}\) 表示当前在点 \(x\),且点 \(x\) 的权值为 \(j\) 的概率。由于 \(n \le 3\times 10^5\),考虑将题目给出的权值离散化。
由于一个点只有两个儿子,考虑点的转移方程。设 \(f_j\) 表示 \(x\) 左儿子的 dp 值,即 \(dp_{ls_x,j}\),\(g_j\) 表示 \(x\) 右儿子的 \(\text{dp}\) 值。(由于保证每个节点的权值互不相同,所以不需要考虑两个儿子 \(\text{dp}\) 值相同的情况)
-
若当前点的权值从左儿子转移
- 若当前权值是左儿子的最大值,则概率为 \(f_j \times p \times \sum_{i=1}^{j-1} g_i\)。
- 若当前权值是左儿子的最小值,则概率为 \(f_j\times (1-p)\times \sum_{i=j+1}^{Max} g_i\)。
-
若当前点的权值从右儿子转移
-
若当前权值是右儿子的最大值,则概率为 \(g_j \times p \times \sum_{i=1}^{j-1} f_i\)。
-
若当前权值是右儿子的最小值,则概率为 \(g_j\times (1-p)\times \sum_{i=j+1}^{Max} f_i\)。
-
综上,将上式全部相加即可得到当前点的转移方程
\[dp_{x,j}=f_j \times (p \times \sum_{i=1}^{j-1} g_i+(1-p)\times \sum_{i=j+1}^{Max} g_i)+g_j \times (p \times \sum_{i=1}^{j-1} f_i+(1-p)\times \sum_{i=j+1}^{Max} f_i)
\]
直接转移可以得到 \(\mathcal{O}(n^2)\) 的做法。
考虑优化,注意到式子中有多个前缀后缀和的形式,考虑使用权值线段树维护 \(dp\) 的第二维 \(j\),向上转移时将线段树合并。
当将以 \(x,y\) 为根的两颗线段树合并时,设当前区间为 \([l,r]\),在 \(\text{merge}\) 过程中记录
\[lx=\sum_{i=1}^{l-1} dp_{x,i}\\
rx=\sum_{i=r+1}^{Max} dp_{x,i}\\
ly=\sum_{i=1}^{l-1} dp_{y,i}\\
ry=\sum_{i=r+1}^{Max} dp_{y,i}\\
\]
分以下情况讨论:
- 当 \(x,y\) 均为空时,直接返回即可。
- 当 \(x\) 为空,\(y\) 不为空时,则 \(\forall i(l\le i\le n,i \in N)\),均有 \(dp_{x,i}=0\)。此时,上面的转移方程可以化为 \(dp_{x,j}=g_j \times (p \times \sum_{i=1}^{j-1} f_i+(1-p)\times \sum_{i=j+1}^{Max} f_i)\),则对于区间 \([l,r]\) 中的任何一个 \(j\),均满足 \(f_{l},f_{l+1},\dots,f_j,\dots,f_{r-1},f_r=0\),则上式又可以化为 \(dp_{x,j}=g_j \times (p \times \sum_{i=1}^{l-1} f_i+(1-p)\times \sum_{i=r+1}^{Max} f_i)\)。使用 \(lx,rx\) 表示则可以得到 \(dp_{x,j}=g_j \times (p \times lx+(1-p)\times rx)\),即右侧与 \(j\) 无关。则对于所有的 \(j\) 满足 \(j \in [l,r]\),都相当于在原来 \(y\) 树的基础上乘了 \((p \times lx+(1-p)\times rx)\)。直接转移打标记即可。
- 当 \(y\) 为空,\(x\) 不为空时,与上面一种情况类似。
- 当 \(x,y\) 均不为空时,维护接下来的 \(lx^{\prime},rx^{\prime},ly^{\prime},ry^{\prime}\) 并向两边同时递归即可。
综上,可以得到时间复杂度为 \(\mathcal{O}(n \log n)\),空间复杂度为 \(\mathcal{O}(n \log n)\) 的做法。
code
#include<bits/stdc++.h>
using namespace std;
namespace IO{
template<typename T>inline bool read(T &x){
x=0;
char ch=getchar();
bool flag=0,ret=0;
while(ch<'0'||ch>'9') flag=flag||(ch=='-'),ch=getchar();
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar(),ret=1;
x=flag?-x:x;
return ret;
}
template<typename T,typename ...Args>inline bool read(T& a,Args& ...args){
return read(a)&&read(args...);
}
template<typename T>void prt(T x){
if(x>9) prt(x/10);
putchar(x%10+'0');
}
template<typename T>inline void put(T x){
if(x<0) putchar('-'),x=-x;
prt(x);
}
template<typename T>inline void put(char ch,T x){
if(x<0) putchar('-'),x=-x;
prt(x);
putchar(ch);
}
template<typename T,typename ...Args>inline void put(T a,Args ...args){
put(a);
put(args...);
}
template<typename T,typename ...Args>inline void put(const char ch,T a,Args ...args){
put(ch,a);
put(ch,args...);
}
inline void put(string s){
for(int i=0,sz=s.length();i<sz;i++) putchar(s[i]);
}
inline void put(const char* s){
for(int i=0,sz=strlen(s);i<sz;i++) putchar(s[i]);
}
}
using namespace IO;
#define N 300005
#define mod 998244353
#define ll long long
inline int power(int x,int y){
int res=1;
while(y){
if(y&1) res=(ll)res*x%mod;
x=(ll)x*x%mod;
y>>=1;
}
return res;
}
int n,son[N][2],num[N],value[N],p[N],b[N],Idx,rt[N],res[N],ans,idx;
struct node{
int ls,rs,sum,tag;
}t[N*25];
#define lc(x) t[x].ls
#define rc(x) t[x].rs
inline void push_down(int x){
if(t[x].tag==1) return;
t[lc(x)].sum=(ll)t[lc(x)].sum*t[x].tag%mod;
t[rc(x)].sum=(ll)t[rc(x)].sum*t[x].tag%mod;
t[lc(x)].tag=(ll)t[lc(x)].tag*t[x].tag%mod;
t[rc(x)].tag=(ll)t[rc(x)].tag*t[x].tag%mod;
t[x].tag=1;
}
inline void push_up(int x){
t[x].sum=(t[lc(x)].sum+t[rc(x)].sum)%mod;
}
inline void update(int &x,int l,int r,int pos,int pro){
if(!x) t[x=++idx].tag=1;
if(l==r) return t[x].sum=pro,void();
int mid=l+r>>1;
if(pos<=mid) update(lc(x),l,mid,pos,pro);
else update(rc(x),mid+1,r,pos,pro);
push_up(x);
}
inline int merge(int x,int y,int lx,int rx,int ly,int ry,int pro){
if(!x&&!y) return 0;
push_down(x),push_down(y);
int xmul=((ll)pro*ly%mod+(ll)(1-pro+mod)*ry%mod)%mod;
int ymul=((ll)pro*lx%mod+(ll)(1-pro+mod)*rx%mod)%mod;
if(!x){
t[y].sum=(ll)t[y].sum*ymul%mod;
t[y].tag=(ll)t[y].tag*ymul%mod;
return y;
}
if(!y){
t[x].sum=(ll)t[x].sum*xmul%mod;
t[x].tag=(ll)t[x].tag*xmul%mod;
return x;
}
int ax=t[lc(x)].sum,bx=t[rc(x)].sum,ay=t[lc(y)].sum,by=t[rc(y)].sum;
lc(x)=merge(lc(x),lc(y),lx,(rx+bx)%mod,ly,(ry+by)%mod,pro);
rc(x)=merge(rc(x),rc(y),(lx+ax)%mod,rx,(ly+ay)%mod,ry,pro);
push_up(x);
return x;
}
void dfs(int x){
if(!num[x]) update(rt[x],1,Idx,value[x],1);
else if(num[x]==1) dfs(son[x][0]),rt[x]=rt[son[x][0]];
else{
dfs(son[x][0]),dfs(son[x][1]);
rt[x]=merge(rt[son[x][0]],rt[son[x][1]],0,0,0,0,p[x]);
}
}
inline void getans(int x,int l,int r){
if(!x) return;
if(l==r) return res[l]=t[x].sum,void();
int mid=l+r>>1;push_down(x);
getans(lc(x),l,mid),getans(rc(x),mid+1,r);
}
int main(){
read(n);
for(int i=1,x;i<=n;i++){
read(x);
if(i==1) continue;
son[x][son[x][0]!=0]=i;
num[x]++;
}
for(int i=1,x;i<=n;i++){
read(x);
if(!num[i]) value[i]=x,b[++Idx]=x;
else p[i]=(ll)x*power(10000,mod-2)%mod;
}
sort(b+1,b+Idx+1);
Idx=unique(b+1,b+Idx+1)-b-1;
for(int i=1;i<=n;i++)
if(value[i]) value[i]=lower_bound(b+1,b+Idx+1,value[i])-b;
dfs(1);
getans(rt[1],1,Idx);
for(int i=1;i<=Idx;i++)
ans=(ans+(ll)i*b[i]%mod*res[i]%mod*res[i]%mod)%mod;
put('\n',ans);
return 0;
}

浙公网安备 33010602011771号