Loading

[PKUWC2018]Minimax(整体DP+线段树合并+概率)

洛谷题目传送门

题目说每一个权值的概率都不为0,所以每个权值都有可能取到
考虑如下的\(dp\):
\(f_{i,j}\)表示点\(i\)取到\(j\)权值的概率,这里的权值可以离散化一下
那么这个\(dp\)是怎么转移的?
·如果这个点是叶子节点,那么只有一种取值是1,其余取值是0,
·如果这个点只有一个儿子,那么整个\(dp\)数组都由那个儿子继承过来
·如果这个点有两个儿子,设为\(x,y\)
那么如果从左儿子转移

\[f_{i,j}=f_{x,a}\times (p_i \times \sum_{b=1}^{a-1}f_{y,b}+(1-p_i)\times \sum _{b=a+1}^n f_{y,b}) \]

又儿子类似可以发现我们实际是将左儿子的某些元素乘上一些值,这样可以用一颗线段树,打乘法标记来维护
但是有两个儿子,考虑线段树合并
在线段树合并的过程中,如果一个节点在a,b中都有,那么可以递归解决
如果只有a有或只有b有
比如只有a有,当前区间是\(l,r\),那么整个当前区间都要乘上
\(p_i\times \sum_{b=1}^{l-1}f_{y,b}+(1-p_i)\times \sum _{b=r+1}^n f_{y,b}\)
其中\(p_i\)是已知量,其余的东西是一颗线段树的前/后缀和,可以线段树合并维护,然后把当前区间打上一个乘法tag,最后遍历整颗线段树一遍就是答案

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int mod = 998244353;
const int N = 3e5+7;
struct edge
{
	int y,next;
}e[2*N];
int link[N],t=0;
void add(int x,int y)
{
	e[++t].y=y;
	e[t].next=link[x];
	link[x]=t;
}
LL p[N];
int Tl[N],Tr[N];
int n;
LL Pow(int a,int b)
{
	LL res=1;
	while(b)
	{
		if(b&1) res=1ll*res*a%mod;
		a=1ll*a*a%mod;
		b>>=1;
	}
	return res;
}
LL Inv;
int m;
int val[N];
LL sum[N*10],tag[N*10];
int rot[N],tot=0,rson[N*10],lson[N*10];
void pushup(int k)
{
	sum[k]=(sum[lson[k]]+sum[rson[k]])%mod;
}
void pushdown(int k)
{
	if(tag[k]==1) return;
	tag[lson[k]]=1ll*tag[lson[k]]*tag[k]%mod;
	tag[rson[k]]=1ll*tag[rson[k]]*tag[k]%mod;
	sum[lson[k]]=1ll*sum[lson[k]]*tag[k]%mod;
	sum[rson[k]]=1ll*sum[rson[k]]*tag[k]%mod;
	tag[k]=1;
}
void Modify(int &k,int l,int r,int x,LL v)
{
	if(!k) k=++tot;
	tag[k]=1;
	if(l==r)
	{
		sum[k]=v;
		return;
	}
	int mid=(l+r)>>1;
	if(x<=mid) Modify(lson[k],l,mid,x,v);
	else Modify(rson[k],mid+1,r,x,v);
	pushup(k);
}
LL P,IP;
int Merge(int x,int y,int l,int r,LL sumx,LL sumy)
{
	if(!x&&!y) return 0;
	if(x) pushdown(x);
	if(y) pushdown(y);
	if(!y)
	{
		LL mul=(1ll*P*sumy%mod+1ll*IP*((1-sumy+mod)%mod)%mod)%mod;
		tag[x]=1ll*tag[x]*mul%mod;
		sum[x]=1ll*sum[x]*mul%mod;
		return x;
	}
	if(!x)
	{
		LL mul=(1ll*P*sumx%mod+1ll*IP*((1-sumx+mod)%mod)%mod)%mod;
		tag[y]=1ll*tag[y]*mul%mod;
		sum[y]=1ll*sum[y]*mul%mod;
		return y;
	}	
	int mid=(l+r)>>1;
	rson[x]=Merge(rson[x],rson[y],mid+1,r,(sumx+sum[lson[x]])%mod,(sumy+sum[lson[y]])%mod);	
	lson[x]=Merge(lson[x],lson[y],l,mid,sumx,sumy);
	pushup(x);
	return x;
}
void dfs(int x)
{
	if(!Tl[x]&&!Tr[x]) return;
	dfs(Tl[x]);
	if(!Tr[x])
	{
		rot[x]=rot[Tl[x]];
		return;
	}
	dfs(Tr[x]);
	P=p[x];
	IP=(1-P+mod)%mod;
	rot[x]=Merge(rot[Tl[x]],rot[Tr[x]],1,m,0,0);
}
bool cmp(int x,int y)
{
	return x<y;
}
int dct[N];
LL ans=0,num=0;
void Getans(int k,int l,int r)
{
	if(!k) return;
	if(l==r)
	{
		++num;
		ans=(ans+1ll*num*dct[num]%mod*sum[k]%mod*sum[k]%mod)%mod;
		return;	
	}
	pushdown(k);
	int mid=(l+r)>>1;
	Getans(lson[k],l,mid);
	Getans(rson[k],mid+1,r);
}
inline int read()
{
    int x=0,t=1;char ch=getchar();
    while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
    if(ch=='-')t=-1,ch=getchar();
    while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
    return x*t;
}
int main()
{
	cin>>n;
	Inv=Pow(10000,mod-2);
	for(int i=1;i<=n;i++)
	{
		int fa=read();
		if(fa==0) continue;
		if(!Tl[fa]) Tl[fa]=i;
		else Tr[fa]=i;
	}
	for(int i=1;i<=n;i++)
	{
		int w=read();
		if(Tl[i]) p[i]=1ll*w*Inv%mod;
		else val[i]=w,dct[++m]=val[i];
	}
	sort(dct+1,dct+m+1,cmp);
	for(int i=1;i<=n;i++)
	if(!Tl[i]) val[i]=lower_bound(dct+1,dct+m+1,val[i])-dct;
	for(int i=1;i<=n;i++)
	if(!Tl[i]) Modify(rot[i],1,m,val[i],1);
	dfs(1);
	Getans(rot[1],1,m); 
	cout<<ans;
	return 0;
} 
posted @ 2021-12-29 10:38  Larunatrecy  阅读(54)  评论(0)    收藏  举报