题目链接

https://www.lydsy.com/JudgeOnline/problem.php?id=5461

题解

线段树合并。将叶子权值离散化后，每个节点维护一棵线段树，线段树上区间 $[l,r]$ 存储该节点取到第 $l$ 小到第 $r$ 小的权值的概率之和。对于每一个非叶节点，它的线段树由两个儿子的线段树合并得到。记点 $u$ 的左、右儿子为 $ls$、$rs$，$p_u$ 为点 $u$ 取两儿子中较大值的概率。容易发现，在点 $u$ 处，对于第 $i$ 小的权值，假设这个权值是由左儿子贡献而来，则点 $u$ 取到这个权值的概率是
$$f_{u,i}=f_{ls,i}\left(p_u\sum_{j<i}f_{rs,j}+(1-p_u)\sum_{j>i}f_{rs,j}\right)$$
右儿子同理。

代码

#include <algorithm>
#include <cstdio>
#include <vector>
 
/*
 * Reads one decimal integer from stdin.
 * Characters before the first digit are skipped; every '-' seen among them
 * flips the sign. Reading stops at the first non-digit after the number.
 * Returns 0 if end of input is reached before any digit is found, instead of
 * looping forever on EOF.
 */
int read()
{
  int x=0,f=1;
  // int, not char: EOF must stay distinguishable from every byte value
  // (with an unsigned char, EOF would turn into 255 and never terminate).
  int ch=getchar();
  while((ch<'0')||(ch>'9'))
    {
      if(ch==EOF)
        {
          return 0;
        }
      if(ch=='-')
        {
          f=-f;
        }
      ch=getchar();
    }
  while((ch>='0')&&(ch<='9'))
    {
      x=x*10+ch-'0';
      ch=getchar();
    }
  return x*f;
}
 
// Upper bound on the number of tree nodes; arrays below add +10 slack.
const int maxn = 300000;
// Prime modulus for all probability arithmetic.
const int mod = 998244353;
// Inverse of 10000 modulo mod. main multiplies internal-node inputs by it,
// i.e. divides them by 10000 in modular arithmetic.
const int inv = 796898467;
 
// Node of a pointer-based (dynamically allocated) segment tree.
struct node
{
  node *son[2];  // son[0] = left half, son[1] = right half; NULL = empty subtree
  int sum;       // sum of the values in this subtree (puttag/updata reduce it mod `mod`)
  int tag;       // pending multiplicative factor not yet applied to the children
};
 
// Static node pool. Nodes are handed out as &tree[++cnt], so tree[0] is never used.
node tree[maxn<<5];
// Index of the most recently allocated pool node (0 = none allocated yet).
int cnt=0;
 
// Resets a freshly allocated node: zero sum, identity tag, no children.
// Always returns 0.
int clear(node *x)
{
  x->sum=0;
  x->tag=1;
  x->son[1]=NULL;
  x->son[0]=NULL;
  return 0;
}
 
// Multiplies every value in subtree x by v (mod `mod`). The stored sum is
// scaled immediately; the factor is also folded into the lazy tag so the
// children can be scaled later. A NULL subtree is ignored. Always returns 0.
int puttag(node *x,int v)
{
  if(x!=NULL)
    {
      x->sum=static_cast<long long>(x->sum)*v%mod;
      x->tag=static_cast<long long>(x->tag)*v%mod;
    }
  return 0;
}
 
// Applies x's pending multiplicative tag to both children (NULL children are
// skipped by puttag), then resets the tag to the identity 1. Always returns 0.
int pushdown(node *x)
{
  for(int k=0; k<2; ++k)
    {
      puttag(x->son[k],x->tag);
    }
  x->tag=1;
  return 0;
}
 
// Sum stored in subtree x; an empty (NULL) subtree contributes 0.
int getsum(node *x)
{
  if(x==NULL)
    {
      return 0;
    }
  return x->sum;
}
 
// Recomputes x->sum from its two children modulo `mod`. Each child sum is
// expected to be below mod already, so one conditional subtraction suffices.
// Always returns 0.
int updata(node *x)
{
  int total=getsum(x->son[0])+getsum(x->son[1]);
  x->sum=(total>=mod)?(total-mod):total;
  return 0;
}
 
/*
 * Merges the segment trees of the two children of an internal node into the
 * tree of that node. Both inputs are consumed: nodes of x are reused in
 * place for the result, and nodes of y that overlap x are abandoned.
 *
 *   x, y    trees of the two children (either may be NULL)
 *   l, r    passed through unchanged; not used by the computation
 *   p       probability (a residue mod `mod`) that the node takes the maximum
 *   xl, xr  total mass of x on ranks strictly left / right of the current
 *           subtree; yl, yr likewise for y
 *
 * A rank i held only by x ends up with probability
 *   f_x(i) * (p * P[y < i] + (1 - p) * P[y > i]),
 * and symmetrically for y. Once one side is empty, every rank remaining on
 * the other side has the same P[< i] and P[> i], so the factor is applied
 * lazily to the whole subtree with puttag.
 *
 * Returns the merged tree.
 */
node *merge(node *x,node *y,int l,int r,int p,int xl,int yl,int xr,int yr)
{
  if(x==NULL)
    {
      // (1 - p) is represented as mod + 1 - p to stay non-negative.
      puttag(y,(1ll*p*xl+1ll*(mod+1-p)*xr)%mod);
      return y;
    }
  if(y==NULL)
    {
      puttag(x,(1ll*p*yl+1ll*(mod+1-p)*yr)%mod);
      return x;
    }
  pushdown(x);
  pushdown(y);
  // Ranks in the right halves see the left halves' mass as "smaller" ...
  int xlp=xl+getsum(x->son[0]),ylp=yl+getsum(y->son[0]);
  // ... and ranks in the left halves see the right halves' mass as "larger".
  int xrp=xr+getsum(x->son[1]),yrp=yr+getsum(y->son[1]);
  if(xlp>=mod)
    {
      xlp-=mod;
    }
  if(ylp>=mod)
    {
      ylp-=mod;
    }
  if(xrp>=mod)
    {
      xrp-=mod;
    }
  if(yrp>=mod)
    {
      yrp-=mod;
    }
  // The caller discards x after merging, so x's node is reused for the
  // result instead of taking a fresh node from the pool for every overlap.
  // x->tag is already 1 after pushdown, and all sums above were read before
  // the children are overwritten.
  x->son[0]=merge(x->son[0],y->son[0],l,r,p,xl,yl,xrp,yrp);
  x->son[1]=merge(x->son[1],y->son[1],l,r,p,xlp,ylp,xr,yr);
  updata(x);
  return x;
}
 
/*
 * Adds v at position pos of the segment tree rooted at x, which covers [l, r].
 * Missing nodes on the path are taken from the pool. Returns the (possibly
 * newly created) root.
 * NOTE(review): the leaf sum is not reduced modulo `mod`. This is fine for
 * the visible caller, which adds v=1 once per leaf.
 */
node *add(node *x,int l,int r,int pos,int v)
{
  if(x==NULL)
    {
      x=&tree[++cnt];
      clear(x);
    }
  if(l!=r)
    {
      pushdown(x);
      int mid=(l+r)>>1;
      if(pos>mid)
        {
          x->son[1]=add(x->son[1],mid+1,r,pos,v);
        }
      else
        {
          x->son[0]=add(x->son[0],l,mid,pos,v);
        }
      updata(x);
    }
  else
    {
      x->sum+=v;
    }
  return x;
}
 
/*
 * Point query: returns the value stored at position pos of the tree rooted
 * at x, which covers [l, r]. Returns 0 if that position was never created.
 * Pending tags are pushed down along the search path, so the leaf read at
 * the end is up to date.
 */
int getsum(node *x,int l,int r,int pos)
{
  while(x!=NULL&&l!=r)
    {
      pushdown(x);
      int mid=(l+r)>>1;
      if(pos<=mid)
        {
          x=x->son[0];
          r=mid;
        }
      else
        {
          x=x->son[1];
          l=mid+1;
        }
    }
  return getsum(x);
}
 
// Child lists of the rooted tree, stored as linked edge lists.
int pre[maxn+10];     // next edge in the same child list (0 terminates)
int now[maxn+10];     // first edge of each node's child list (0 = leaf)
int son[maxn+10];     // child endpoint of each edge
int tot;              // number of edges inserted so far
int p[maxn+10];       // leaf: rank of its weight; internal node: probability residue
int n;                // number of nodes
int top;              // number of leaves (size of the rank domain)
node *root[maxn+10];  // per-node segment tree over leaf ranks
 
// Prepends a new edge to a's child list, recording b as a child of a.
// Always returns 0.
int ins(int a,int b)
{
  ++tot;
  son[tot]=b;
  pre[tot]=now[a];
  now[a]=tot;
  return 0;
}
 
/*
 * Builds root[] for every node in the subtree of u, bottom-up:
 *  - a leaf gets a tree holding 1 at its rank p[leaf];
 *  - a node with one child reuses that child's tree;
 *  - a node with two children merges them using its probability p[node].
 * Done iteratively: the input tree can be a chain of n nodes, and a
 * recursive DFS of that depth can overflow the call stack.
 * Always returns 0.
 */
int search(int u)
{
  // BFS order lists every parent before its children, so walking it
  // backwards finishes all children before their parent is handled.
  std::vector<int> order(1,u);
  for(size_t head=0; head<order.size(); ++head)
    {
      for(int i=now[order[head]]; i; i=pre[i])
        {
          order.push_back(son[i]);
        }
    }
  for(size_t k=order.size(); k-->0; )
    {
      int v=order[k];
      if(!now[v])
        {
          root[v]=add(root[v],1,top,p[v],1);
          continue;
        }
      node *ls=NULL,*rs=NULL;
      for(int i=now[v]; i; i=pre[i])
        {
          if(ls==NULL)
            {
              ls=root[son[i]];
            }
          else
            {
              rs=root[son[i]];
            }
        }
      if(rs==NULL)
        {
          root[v]=ls;
        }
      else
        {
          root[v]=merge(ls,rs,1,top,p[v],0,0,0,0);
        }
    }
  return 0;
}
 
// A leaf of the input tree: its node index and its weight.
// Instances compare by weight only.
struct data
{
  int id;
  int val;

  data(int _id=0,int _val=0)
  {
    id=_id;
    val=_val;
  }

  bool operator <(const data &other) const
  {
    return other.val>val;
  }
};
 
// Leaves collected in main, 1-based; after sorting by weight, d[i] is the leaf of rank i.
data d[maxn+10];
 
/*
 * Reads the tree and converts leaf weights to ranks and internal-node inputs
 * to probability residues. Then builds the root's rank distribution and prints
 *   sum over ranks i of  i * V_i * D_i^2   (mod `mod`),
 * where V_i is the i-th smallest leaf weight and D_i is the probability that
 * the root takes it.
 */
int main()
{
  n=read();
  // First list: parent of node i; 0 means "no parent".
  for(int i=1; i<=n; ++i)
    {
      int parent=read();
      if(parent!=0)
        {
          ins(parent,i);
        }
    }
  // Second list: a weight for a leaf, a scaled probability for an internal node.
  for(int i=1; i<=n; ++i)
    {
      p[i]=read();
      if(now[i])
        {
          p[i]=1ll*p[i]*inv%mod;  // divide by 10000 modulo mod
        }
      else
        {
          ++top;
          d[top]=data(i,p[i]);
        }
    }
  std::sort(d+1,d+top+1);
  for(int rank=1; rank<=top; ++rank)
    {
      p[d[rank].id]=rank;
    }
  search(1);
  int ans=0;
  for(int rank=1; rank<=top; ++rank)
    {
      int prob=getsum(root[1],1,top,rank);
      long long term=1ll*rank*d[rank].val%mod*prob%mod*prob%mod;
      ans=(ans+term)%mod;
    }
  printf("%d\n",ans);
  return 0;
}