「ZJOI2017」树状数组(二维线段树)

「ZJOI2017」树状数组(二维线段树)

吉老师的题目真是难想。。。

代码中求的是 \(\sum_{i=l-1}^{r-1}a_i\),而实际求的是 \(\sum_{i=l}^{r}a_i\),所以我们直接判断 \(a_{l-1}\)\(a_r\) 是否相等就行了。

我们用二维线段树,一维存左端点 \(l\),一维存右端点 \(r\),里面存 \(a_l=a_r\) 的概率。

\(a\in [1,l-1],b\in [l,r]\),操作不在 \(b\),概率为 \(1-p\)

\(a\in [l,r],b\in [l,r]\),操作不在 \(a\)\(b\),概率为 \(1-2\times p\)

\(a\in [r+1,n],b\in [l,r]\),操作不在 \(b\),概率为 \(1-p\)

如果左边相等的概率是 \(p\),右边相等的概率是 \(q\),那么总概率是 \(p\times q+(1-p)\times (1-q)\)

我们标记永久化一下就能做到 \(O(n\log^2 n)\)

但是有问题!\(l=1\) 时求的是前缀和等于后缀和的概率!

那么我们多开一棵线段树记录前缀和等于后缀和的概率。

\(a\in [1,l-1]\),没有操作满足,概率为 \(0\)

\(a\in [l,r]\),操作刚好在 \(a\),概率为 \(p\)

\(a\in [r+1,n]\),没有操作满足,概率为 \(0\)

所以仔细一看是一个二合一的题目。。。

\(Code\ Below:\)

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=100000+10;
const int mod=998244353;
int n,m,T[maxn<<2],tot;

/*
a -> [1, l - 1]  b -> [l, r]        1 - p
a -> [l, r]      b -> [l, r]        1 - 2 * p
a -> [r + 1, n]  b -> [l, r]        1 - p
a -> [0, 0]      b -> [1, l - 1]    0
a -> [0, 0]      b -> [l, r]        p
a -> [0, 0]      b -> [r + 1, n]    0
*/

struct node{
    int ls,rs,val;
    inline void init(){ls=rs=0,val=1;}
}t[maxn*400];

inline int read(){
    register int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return (f==1)?x:-x;
}

inline int fpow(int a,int b){
    int ret=1;
    for(;b;b>>=1,a=1ll*a*a%mod)
        if(b&1) ret=1ll*ret*a%mod;
    return ret;
}

inline int merge(int x,int y){
    return (1ll*x*y+1ll*(mod+1ll-x)*(mod+1ll-y))%mod;
}

namespace ST{
    void update(int &x,int L,int R,int C,int l,int r){
        if(!x) x=++tot,t[x].init();
        if(L <= l && r <= R){t[x].val=merge(t[x].val,C);return;}
        int mid=(l+r)>>1;
        if(L <= mid) update(t[x].ls,L,R,C,l,mid);
        if(R > mid) update(t[x].rs,L,R,C,mid+1,r);
    }
    int query(int x,int l,int r,int k){
        if(!x) return 1;
        if(l == r) return t[x].val;
        int mid=(l+r)>>1;
        if(k <= mid) return merge(t[x].val,query(t[x].ls,l,mid,k));
        else return merge(t[x].val,query(t[x].rs,mid+1,r,k));
    }
}

#define lson (rt<<1)
#define rson (rt<<1|1)

void update(int L,int R,int x,int y,int v,int l,int r,int rt){
    if(L <= l && r <= R){ST::update(T[rt],x,y,v,0,n);return;}
    int mid=(l+r)>>1;
    if(L <= mid) update(L,R,x,y,v,l,mid,lson);
    if(R > mid) update(L,R,x,y,v,mid+1,r,rson);
}

int query(int x,int y,int l,int r,int rt){
    if(l == r) return ST::query(T[rt],0,n,y);
    int mid=(l+r)>>1;
    if(x <= mid) return merge(ST::query(T[rt],0,n,y),query(x,y,l,mid,lson));
    else return merge(ST::query(T[rt],0,n,y),query(x,y,mid+1,r,rson));
}

int main()
{
    n=read(),m=read();
    int op,l,r,p;
    while(m--){
        op=read(),l=read(),r=read();
        if(op==1){
            p=fpow(r-l+1,mod-2);
            update(l,r,l,r,(mod+1-2*p%mod)%mod,0,n,1);
            update(0,0,l,r,p,0,n,1);
            if(l>1){
                update(1,l-1,l,r,(mod+1-p)%mod,0,n,1);
                update(0,0,1,l-1,0,0,n,1);
            }
            if(r<n){
                update(l,r,r+1,n,(mod+1-p)%mod,0,n,1);
                update(0,0,r+1,n,0,0,n,1);
            }
        }
        if(op==2) printf("%d\n",query(l-1,r,0,n,1));
    }
    return 0;
}
posted @ 2019-03-16 20:02 Owen_codeisking 阅读(...) 评论(...) 编辑 收藏