[PKUWC2018]Minimax(整体DP+线段树合并+概率)
洛谷题目传送门
题目说每一个权值的概率都不为0,所以每个权值都有可能取到
考虑如下的\(dp\):
设\(f_{i,j}\)表示点\(i\)取到\(j\)权值的概率,这里的权值可以离散化一下
那么这个\(dp\)是怎么转移的?
·如果这个点是叶子节点,那么只有一种取值是1,其余取值是0,
·如果这个点只有一个儿子,那么整个\(dp\)数组都由那个儿子继承过来
·如果这个点有两个儿子,设为\(x,y\)
那么如果从左儿子转移
则
\[f_{i,j}=f_{x,a}\times (p_i \times \sum_{b=1}^{a-1}f_{y,b}+(1-p_i)\times \sum _{b=a+1}^n f_{y,b})
\]
又儿子类似可以发现我们实际是将左儿子的某些元素乘上一些值,这样可以用一颗线段树,打乘法标记来维护
但是有两个儿子,考虑线段树合并
在线段树合并的过程中,如果一个节点在a,b中都有,那么可以递归解决
如果只有a有或只有b有
比如只有a有,当前区间是\(l,r\),那么整个当前区间都要乘上
\(p_i\times \sum_{b=1}^{l-1}f_{y,b}+(1-p_i)\times \sum _{b=r+1}^n f_{y,b}\)
其中\(p_i\)是已知量,其余的东西是一颗线段树的前/后缀和,可以线段树合并维护,然后把当前区间打上一个乘法tag,最后遍历整颗线段树一遍就是答案
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int mod = 998244353;
const int N = 3e5+7;
struct edge
{
int y,next;
}e[2*N];
int link[N],t=0;
void add(int x,int y)
{
e[++t].y=y;
e[t].next=link[x];
link[x]=t;
}
LL p[N];
int Tl[N],Tr[N];
int n;
LL Pow(int a,int b)
{
LL res=1;
while(b)
{
if(b&1) res=1ll*res*a%mod;
a=1ll*a*a%mod;
b>>=1;
}
return res;
}
LL Inv;
int m;
int val[N];
LL sum[N*10],tag[N*10];
int rot[N],tot=0,rson[N*10],lson[N*10];
void pushup(int k)
{
sum[k]=(sum[lson[k]]+sum[rson[k]])%mod;
}
void pushdown(int k)
{
if(tag[k]==1) return;
tag[lson[k]]=1ll*tag[lson[k]]*tag[k]%mod;
tag[rson[k]]=1ll*tag[rson[k]]*tag[k]%mod;
sum[lson[k]]=1ll*sum[lson[k]]*tag[k]%mod;
sum[rson[k]]=1ll*sum[rson[k]]*tag[k]%mod;
tag[k]=1;
}
void Modify(int &k,int l,int r,int x,LL v)
{
if(!k) k=++tot;
tag[k]=1;
if(l==r)
{
sum[k]=v;
return;
}
int mid=(l+r)>>1;
if(x<=mid) Modify(lson[k],l,mid,x,v);
else Modify(rson[k],mid+1,r,x,v);
pushup(k);
}
LL P,IP;
int Merge(int x,int y,int l,int r,LL sumx,LL sumy)
{
if(!x&&!y) return 0;
if(x) pushdown(x);
if(y) pushdown(y);
if(!y)
{
LL mul=(1ll*P*sumy%mod+1ll*IP*((1-sumy+mod)%mod)%mod)%mod;
tag[x]=1ll*tag[x]*mul%mod;
sum[x]=1ll*sum[x]*mul%mod;
return x;
}
if(!x)
{
LL mul=(1ll*P*sumx%mod+1ll*IP*((1-sumx+mod)%mod)%mod)%mod;
tag[y]=1ll*tag[y]*mul%mod;
sum[y]=1ll*sum[y]*mul%mod;
return y;
}
int mid=(l+r)>>1;
rson[x]=Merge(rson[x],rson[y],mid+1,r,(sumx+sum[lson[x]])%mod,(sumy+sum[lson[y]])%mod);
lson[x]=Merge(lson[x],lson[y],l,mid,sumx,sumy);
pushup(x);
return x;
}
void dfs(int x)
{
if(!Tl[x]&&!Tr[x]) return;
dfs(Tl[x]);
if(!Tr[x])
{
rot[x]=rot[Tl[x]];
return;
}
dfs(Tr[x]);
P=p[x];
IP=(1-P+mod)%mod;
rot[x]=Merge(rot[Tl[x]],rot[Tr[x]],1,m,0,0);
}
bool cmp(int x,int y)
{
return x<y;
}
int dct[N];
LL ans=0,num=0;
void Getans(int k,int l,int r)
{
if(!k) return;
if(l==r)
{
++num;
ans=(ans+1ll*num*dct[num]%mod*sum[k]%mod*sum[k]%mod)%mod;
return;
}
pushdown(k);
int mid=(l+r)>>1;
Getans(lson[k],l,mid);
Getans(rson[k],mid+1,r);
}
inline int read()
{
int x=0,t=1;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=-1,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return x*t;
}
int main()
{
cin>>n;
Inv=Pow(10000,mod-2);
for(int i=1;i<=n;i++)
{
int fa=read();
if(fa==0) continue;
if(!Tl[fa]) Tl[fa]=i;
else Tr[fa]=i;
}
for(int i=1;i<=n;i++)
{
int w=read();
if(Tl[i]) p[i]=1ll*w*Inv%mod;
else val[i]=w,dct[++m]=val[i];
}
sort(dct+1,dct+m+1,cmp);
for(int i=1;i<=n;i++)
if(!Tl[i]) val[i]=lower_bound(dct+1,dct+m+1,val[i])-dct;
for(int i=1;i<=n;i++)
if(!Tl[i]) Modify(rot[i],1,m,val[i],1);
dfs(1);
Getans(rot[1],1,m);
cout<<ans;
return 0;
}

浙公网安备 33010602011771号