P5287 [HNOI2019]JOJO border理论 主席树

题意:

戳这里

分析:

  • 暴力

直接KMP,复杂度O(\(n^2\))

  • 正解

首先因为不强制在线,我们可以建出操作树,然后DFS解决操作2的问题,然后我们考虑操作1怎么做,由于每一次暴力KMP的复杂度过高,所以我们要优化KMP。

我们把压缩的字符看成一个二元组,对于一个前缀,如果他是border,他一定满足除了第一段和最后一段以外,中间所有的二元组都和后缀相同。所以我们考虑加入一个二元组 (x,c) 以后答案会发生那些变化,会不断跳nxt,直到找到了一个前缀满足其下一个二元组是 (x',c) ,此时我们会增加 min(x,x') 个border,他们的最后一段是一组公差为1的等差数列,对于剩下的串 x-min(x,x') 接着跳nxt

对于跳的途中,我们若是找到了满足x'=x的二元组 (x',c) ,那将当前二元组的nxt指向找到的这个,否则指向0,因为对于 x'>x 的情况由于任意两个连续二元组的字符不相同,所以必定无法匹配上

但是这样的优化理论上还是过不了这道题,因为暴力跳 nxt 会被卡成单次 O(n) 的,所以总的复杂度还是 O(\(n^2\)) ,虽然实际上真的过了

我们考虑建出一个 KMP自动机(借用网上巨佬的说法),我们对于每一个位置维护一下rt[] 数组表示接一个二元组字符为 i 时会跳到哪里,但是由于二元组的长度不同跳到的 nxt 还是不同,所以我们对于每一个 rt[] 都建一颗主席树,主席树的下标是后接二元组对应的长度,然后我们把统计的答案拆开,每一个border拆成跳nxt得到的前缀的长度+延伸的二元组 (x',c) 中匹配到的长度,我们记 len=min(x,x') ,也就是说贡献拆开后的第二部分是一个长度为 len 的等差数列,第一部分我们就放到主席树上面直接区间查询出来

对于剩下的 x-len 个后缀,我们可能没有办法直接找到一个二元组满足条件的二元组 (x,c) 所以我们需要特判一下,第一个二元组 (\(x_1,c_1\)) 能不能形成一个符合要求的border,能形成的充要条件就是:\(c_1=c\)\(x_1>x\)

统计完答案之后我们要球盖主席树上的信息,我们把 rt[c] 对应的主席树上 [1,x] 的区间的信息都改成当前的前缀长度,这样会更优,然后继续遍历,将跳到的 nxt 节点的 rt[] 数组的信息, 传给下一个节点

代码:

#include<bits/stdc++.h>
#define pii pair<int,int>
#define mk(x,y) make_pair(x,y)
#define pb push_back
#define fir first
#define sec second
#define inl inline
#define reg register

using namespace std;

namespace zzc
{
    inl int read()
    {
        int x=0,f=1;char ch=getchar();
        while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
        while(isdigit(ch)){x=x*10+ch-48;ch=getchar();}
        return x*f;
    }

    const int maxn = 1e5+5;
    const int maxm = 6.4e6+5;
    const int mod = 998244353;
    const int bj = 10005;
    pii val[maxn];
    vector<int> to[maxn];
    int n,tot,top,tim;
    int id[maxn],ans[maxn],lc[maxm],rc[maxm],sum[maxm],nxt[maxm],tag[maxm],st[maxn],a[maxn],b[maxn],mx[maxn][26],rt[maxn][26];
    
    inl void new_node(int &rt) 
    {
        int tmp=++tot;
        sum[tmp]=sum[rt];
        lc[tmp]=lc[rt];
        rc[tmp]=rc[rt];
        nxt[tmp]=nxt[rt];
        tag[tmp]=tag[rt];
        rt=tmp;
    }

    inl void pushup(int rt)
    {
        sum[rt]=(sum[lc[rt]]+sum[rc[rt]])%mod;
    }
    
    inl void form(int rt,int s,int len)
    {
        sum[rt]=1ll*s*len%mod;
        tag[rt]=s;
    }

    inl void pushdown(int rt,int l,int r)
    {
        if(!tag[rt]) return ;
        int mid=(l+r)>>1;
        new_node(lc[rt]);form(lc[rt],tag[rt],mid-l+1);
        new_node(rc[rt]);form(rc[rt],tag[rt],r-mid);
        tag[rt]=0;
    }

    void modify(int &rt,int l,int r,int pos,int st,int k)
    {
        new_node(rt);
        if(r<pos) return form(rt,st,r-l+1);
        if(l==r)
        {
            form(rt,st,1);
            nxt[rt]=k;
            return ;
        }
        int mid=(l+r)>>1;
        pushdown(rt,l,r);
        modify(lc[rt],l,mid,pos,st,k);
        if(pos>mid) modify(rc[rt],mid+1,r,pos,st,k);
        pushup(rt);
    }

    void query(int &rt,int l,int r,int pos,int &res,int &k)
    {
        if(r<pos) return res=(res+sum[rt])%mod,void();
        if(l==r)
        {
            res=(res+sum[rt])%mod;
            k=nxt[rt];
            return ;
        }
        int mid=(l+r)>>1;
        pushdown(rt,l,r);
        query(lc[rt],l,mid,pos,res,k);
        if(pos>mid) query(rc[rt],mid+1,r,pos,res,k);
    }

    inl int getsum(int x)
    {
        return (1ll*(x+1)*x/2)%mod;
    }

    void dfs(int u)
    {
        top++;
        int x=val[u].sec,y=val[u].fir,_nxt=0;
        a[top]=x;b[top]=b[top-1]+y;
        if(top==1) ans[u]=getsum(y-1);
        else
        {
            ans[u]=(ans[u]+getsum(min(mx[top][x],y)))%mod;
            query(rt[top][x],1,bj,y,ans[u],_nxt);
            if(!_nxt&&a[1]==x&&b[1]<y) _nxt=1,ans[u]=(ans[u]+1ll*b[1]*max(0,y-mx[top][x])%mod)%mod;
        }
        mx[top][x]=max(mx[top][x],y);
        modify(rt[top][x],1,bj,y,b[top-1],top);
        for(auto v:to[u])
        {
            memcpy(mx[top+1],mx[_nxt+1],sizeof(mx[top+1]));
            memcpy(rt[top+1],rt[_nxt+1],sizeof(rt[top+1]));
            ans[v]=ans[u];
            dfs(v);
        }
        top--;
    }

    void work()
    {
        int opt,x;
        char ch[5];
        n=read();
        for(reg int i=1;i<=n;i++)
        {
            opt=read();x=read();
            if(opt==1)
            {
                scanf("%s",ch+1);
                val[++tim]=mk(x,ch[1]-'a');
                id[i]=tim;
                to[id[i-1]].pb(id[i]);
            }
            else id[i]=id[x];
        }
        for(auto v:to[0])
        {
            tot=0;
            memset(rt[1],0,sizeof(rt[1]));
            memset(mx[1],0,sizeof(mx[1]));
            dfs(v);
        }
        for(reg int i=1;i<=n;i++) printf("%d\n",ans[id[i]]);
    }


}

int main()
{
    zzc::work();
    return 0;
}
posted @ 2021-02-17 14:41  youth518  阅读(115)  评论(0)    收藏  举报