【LOJ2537】Minimax(PKUWC2018)-树形DP+线段树合并

测试地址:Minimax
做法:本题需要用到树形DP+线段树合并。
很快想到一种定义状态的方式:f(i,j)表示点i的权值为j(离散化后)的概率,然后转移时,因为这是棵二叉树,令son为当前枚举的儿子,sum(j)为另一个儿子中权值小于j的概率,我们有以下状态转移方程:
f(i,j)=sonj=1nf(son,j)[pisum(j)+(1pi)(1sum(j))]
因为题目中说明了叶子节点权值各不相同,所以不用考虑重复的问题。那么我们就可以这样转移,时间复杂度为O(n2)
然而显然过不了,这怎么办呢?我们发现,上面的转移过程就是把两个序列拼在一起,然后再进行某些区间乘就是新的概率。我们知道求区间和和维护区间乘可以用线段树维护,那么这个“序列合并”的操作显然就可以用线段树合并来维护了,因为区间乘操作只会出现在两棵树中仅在一棵中存在的子树中,所以直接在线段树合并时打标记即可。于是我们就完成了这一题,时间复杂度为O(nlogn)
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
int n,first[300010]={0},tot=0,totp=0,rt[300010]={0},ch[6000010][2]={0};
ll p[6000010],seg[6000010]={0},tag[6000010]={0};
ll pxL,pyL,ans;
struct edge
{
    int v,next;
}e[300010];
struct forsort
{
    int id;
    ll val;
}f[300010];

bool cmp(forsort a,forsort b)
{
    return a.val<b.val;
}

void insert(int a,int b)
{
    e[++tot].v=b;
    e[tot].next=first[a];
    first[a]=tot;
}

ll power(ll a,ll b)
{
    ll s=1,ss=a;
    while(b)
    {
        if (b&1) s=s*ss%mod;
        ss=ss*ss%mod;b>>=1;
    }
    return s;
}

void update(int x,ll p)
{
    seg[x]=seg[x]*p%mod;
    tag[x]=tag[x]*p%mod;
}

void pushdown(int no)
{
    if (tag[no]!=1)
    {
        if (ch[no][0]) update(ch[no][0],tag[no]);
        if (ch[no][1]) update(ch[no][1],tag[no]);
        tag[no]=1;
    }
}

void pushup(int no)
{
    seg[no]=(seg[ch[no][0]]+seg[ch[no][1]])%mod;
}

void seginsert(int &no,int l,int r,int x)
{
    if (!no) no=++totp;
    tag[no]=1;
    if (l==r) {seg[no]=1;return;}
    int mid=(l+r)>>1;
    if (x<=mid) seginsert(ch[no][0],l,mid,x);
    else seginsert(ch[no][1],mid+1,r,x);
    pushup(no);
}

int merge(int x,int y,ll p)
{
    if (!x)
    {
        if (!y) return y;
        pyL=(pyL+seg[y])%mod;
        update(y,(p*pxL%mod+(1-p+mod)*(1-pxL+mod)%mod)%mod);
        return y;
    }
    if (!y)
    {
        pxL=(pxL+seg[x])%mod;
        update(x,(p*pyL%mod+(1-p+mod)*(1-pyL+mod)%mod)%mod);
        return x;
    }
    pushdown(x),pushdown(y);
    ch[x][0]=merge(ch[x][0],ch[y][0],p);
    ch[x][1]=merge(ch[x][1],ch[y][1],p);
    pushup(x);
    return x;
}

void solve(int v)
{
    int lson=0,rson=0;
    if (first[v])
    {
        lson=e[first[v]].v;
        solve(lson);
    }
    if (e[first[v]].next)
    {
        rson=e[e[first[v]].next].v;
        solve(rson);
    }
    if (!lson) return;
    if (!rson) rt[v]=rt[lson];
    else
    {
        pxL=pyL=0;
        rt[v]=merge(rt[lson],rt[rson],p[v]);
    }
}

void query(int no,int l,int r)
{
    if (l==r)
    {
        ans=(ans+(ll)l*f[l].val%mod*seg[no]%mod*seg[no]%mod)%mod;
        return;
    }
    int mid=(l+r)>>1;
    pushdown(no);
    query(ch[no][0],l,mid);
    query(ch[no][1],mid+1,r);
}

int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    {
        int x;
        scanf("%d",&x);
        if (x) insert(x,i);
    }

    tot=0;
    ll inv=power(10000,mod-2);
    for(int i=1;i<=n;i++)
    {
        if (!first[i])
        {
            scanf("%lld",&f[++tot].val);
            f[tot].id=i;
        }
        else
        {
            scanf("%lld",&p[i]); 
            p[i]=p[i]*inv%mod;
        }
    }
    sort(f+1,f+tot+1,cmp);
    for(int i=1;i<=tot;i++)
    {
        rt[f[i].id]=++totp;
        seginsert(rt[f[i].id],1,tot,i);
    } 

    solve(1);
    ans=0;
    query(rt[1],1,tot);
    printf("%lld",ans);

    return 0;
}
posted @ 2018-06-21 09:38  Maxwei_wzj  阅读(101)  评论(0编辑  收藏  举报