CF710F String Set Queries

CF710F 题解

好题一道,记录一下。

题目传送门CF710F

题意简述:每次加字符串,删字符串,维护一个字符串集,求字符串集中的串在模板串的出现次数,强制在线。

数据范围:操作数\(n\le 2\times 10^5\),记字符串总长\(S=\sum|s_i|\le3\times10^5\)

为了叙述方便,默认\(n,S\)同阶。

Solution

方法一

这道题有一个相当神奇的性质,现在才发现感觉错过了一个亿

特殊性质:在这个字符串集中,长度不同的串只有\(O(\sqrt n)\)个。

证明:因为字符串总长\(S=\sum|s_i|\le3\times10^5\),我们贪心构造一下,让长度不同的串最多。

发现当串长为\(1,2,3,\cdots,(t-1),t\)时长度不同的串最多。

用等差数列求和公式,有\(\frac{t(t+1)}{2}=n\),解得发现\(t\)的大小为\(O(\sqrt n)\)

解法:这是一个很好的性质,看到根号,想到暴力做

对于每种长度相同的串,只需每次\(O(1)\)匹配一下在模板串即可。

于是我们需要一个最坏\(O(1)\)查询子串是否出现,\(O(\sqrt n)\)插入/删除的数据结构。

于是考虑字符串hash,按串长分组,再建一个哈希表维护字符串的位置,这样每次统计答案就为\(O(1)\)。插入更好做,直接插入即可。

P.S.这道题卡自然溢出字符串hash,且此做法时间复杂度为\(O(n\sqrt n)\),几乎是卡着过的,实测双模数TLE,所以只能写单模数字符串hash

听xym说CF中unordered_map要重载什么的运算符,于是手写哈希表。

代码如下:(2.56s)

#include<bits/stdc++.h>
#define x first
#define y second
typedef long long ll;
using namespace std;
const int N=3e5+10,base1=131,mod1=1e9+7,base2=1331,mod2=1e9+9;
const int mod=5e3+3;

pair<ll,ll>p[N],f[N];
int n,len,dfn[N],id[N],tot;
char s[N];
struct Map{//手写哈希表
    struct node{pair<ll,ll>Hash;int cnt;};
    vector<node>A[mod];
    void insert(pair<ll,ll>x){
        int t=x.x%mod;
        for(int i=0;i<A[t].size();i++)
            if(A[t][i].Hash==x){A[t][i].cnt++;return;}
        A[t].push_back({x,1});
    }
    void erase(pair<ll,ll>x){
        int t=x.x%mod;
        for(int i=0;i<A[t].size();i++)
            if(A[t][i].Hash==x){
                if(--A[t][i].cnt==0)
                    swap(A[t][i],A[t].back()),A[t].pop_back();
                return;
            }
    }
    int count(pair<ll,ll>x){
        int t=x.x%mod;
        for(int i=0;i<A[t].size();i++)
            if(A[t][i].Hash==x)return A[t][i].cnt;
        return 0;
    }
}mp[780];

int main(){
    p[0]=make_pair(1,1);
    for(int i=1;i<N;i++){//初始化
        p[i].x=p[i-1].x*base1%mod1;
        //p[i].y=p[i-1].y*base2%mod2;
    }
    scanf("%d",&n);
    for(int i=1,op;i<=n;i++){
        scanf("%d%s",&op,s+1);
        len=strlen(s+1);
        if(op==1){
            pair<ll,ll>Hash=make_pair(0,0);
            for(int j=1;j<=len;j++){
                Hash.x=(Hash.x*base1+s[j])%mod1;
                //Hash.y=(Hash.y*base2+s[j])%mod2;
            }
            if(!id[len])id[len]=++tot,dfn[tot]=len;//插入
            mp[id[len]].insert(Hash);
        }else if(op==2){
            pair<ll,ll>Hash=make_pair(0,0);
            for(int j=1;j<=len;j++){
                Hash.x=(Hash.x*base1+s[j])%mod1;
                //Hash.y=(Hash.y*base2+s[j])%mod2;
            }
            mp[id[len]].erase(Hash);//删除
        }else if(op==3){
            f[0]=make_pair(0,0);
            for(int j=1;j<=len;j++){
                f[j].x=(f[j-1].x*base1+s[j])%mod1;
                //f[j].y=(f[j-1].y*base2+s[j])%mod2;
            }
            int ans=0;
            for(int j=1,L;j<=tot;j++)//每个len都查一下
                for(int k=L=dfn[j];k<=len;k++){
                    pair<ll,ll>Hash;
                    Hash.x=(f[k].x+mod1-p[L].x*f[k-L].x%mod1)%mod1;
                    //Hash.y=(f[k].y+mod2-p[L].y*f[k-L].y%mod2)%mod2;
                    ans+=mp[j].count(Hash);//统计答案
                }
            printf("%d\n",ans);
            fflush(stdout);
        }
    }
    return 0;
}

方法二

如果你学过AC自动机且你做过P5357 【模板】AC 自动机(二次加强版),你会发现这道题本质上就是带删减的 P5357。

做过的回忆一下当时是怎么做的。

在 P5357 中,我们建出AC自动机,发现Fail树有一些特殊的性质。

可以画图理解一下,在Fail树中,如果一个串的父节点为该串的一个后缀,它是它的子节点的后缀。

于是我们可以标记一下对于模板串的每个前缀,它们的父节点构成的集合为模板串的所有子串构成的集合。

于是标记一下模板串的子串在Fail树的出现位置(就是匹配跑一遍),每个串对答案的贡献就是在Fail树中以该点为根的子树中的标记点数(就是统计被标记的点数)。

这就是不带删减的 CF710F。

回到本题,考虑怎么删减字符串,AC自动机相当于一个离线做法,对于答案只能离线下来一遍dfs统计。

但是本题强制在线,又该则么做呢?

回想AC自动机,他只能每次重构,考虑建两个AC自动机,答案即为ask(0,s)-ask(1,s)

这时只剩加入操作,但我们还是做不了。

这是,有一个神奇的想法:二进制拆分。相当考验人类智慧。(对于这种只能重构,每次插入的都可做)。

维护每个AC自动机存了多少个串,如果有两个AC自动机存的串相同,就将他们合并。

手玩一下发现,每个AC自动机的大小只能为2的次幂,即\(1, 2,4,8\cdots\)

于是对于每个串,最多被重构\(O(log n)\)次,每次重构的时间复杂度只与\(AC\)自动机的大小有关,总的时间复杂度为\(O(n log n)\)

具体的,插入一个串时,对该串建一个全新的AC自动机,像二进制进位一样向前合并,均摊\(O(logn)\)

查询答案时,对于每个AC自动机,在 P5357 中,我们考虑一个字符集中的串对答案的贡献,在本题中,我们反过来,考虑模板串中的每个点有多少个祖先在字符集中,容易发现这是等价的,在每个AC自动机查询即可。

总的时间复杂度为\(O(n logn)\),因为AC自动机自带26的常数,所以没比字符串hash快多少。

代码如下:(2.45s)

#include<bits/stdc++.h>
using namespace std;
const int N=3e5+10;

char s[N];
int t[N][30],T[N][30],fail[N],cnt[N];//第一个t用来合并,第二个T用来合并后求Trie图
int root[2][N],tot,num[2],n,sum[N],siz[2][N];
void insert(char*str,int root){
    int p=root;
    for(int i=1;str[i];i++){
        int k=str[i]-'a';
        if(!t[p][k])t[p][k]=++tot;
        p=t[p][k];
    }
    cnt[p]++;
}
void build(int root){
    queue<int>q;
    for(int i=0;i<26;i++)
        if(t[root][i]){
            T[root][i]=t[root][i];
            fail[t[root][i]]=root;//放置跳到别的树上
            q.push(t[root][i]);
        }else T[root][i]=root;//注意这里,调试了n年,可以“停”在根节点
    while(q.size()){
        int x=q.front();q.pop();
        for(int i=0;i<26;i++){
            int p=t[x][i];
            if(!p)T[x][i]=T[fail[x]][i];
            else{
                T[x][i]=t[x][i];
                fail[p]=T[fail[x]][i];
                q.push(T[x][i]);
            }
        }
        sum[x]=cnt[x]+sum[fail[x]];//统计祖先在字符串集中的个数
    }
}
int merge(int x,int y){//实际上合并两个trie树就行
    if(!x||!y)return x+y;
    cnt[x]+=cnt[y];
    for(int i=0;i<26;i++)
        t[x][i]=merge(t[x][i],t[y][i]);
    return x;
}
void insert(int op,char*str){
    root[op][++num[op]]=++tot,siz[op][num[op]]=1;
    insert(str,tot);
    while(siz[op][num[op]]==siz[op][num[op]-1]&&num[op]>1){//二进制拆分
        num[op]--,siz[op][num[op]]+=siz[op][num[op]+1];
        root[op][num[op]]=merge(root[op][num[op]],root[op][num[op]+1]);
    }
    build(root[op][num[op]]);
}
int ask(int op,char*str){
    int ans=0;
    for(int i=1;i<=num[op];i++){//每个AC自动机都求一下
        int p=root[op][i];
        for(int j=1;str[j];j++){
            int k=str[j]-'a';
            p=T[p][k];
            ans+=sum[p];
        }
    }
    return ans;
}

int main(){
    scanf("%d",&n);
    for(int i=1,op;i<=n;i++){
        scanf("%d%s",&op,s+1);
        if(op==1||op==2)insert(op-1,s);
        else if(op==3)printf("%d\n",ask(0,s)-ask(1,s));//化删除为插入
        fflush(stdout);
    }
    return 0;
}

总结:

两种方法都非常巧妙,又相当套路,所以记录一下。

方法一中常见的性质,长度不同的字符串只有\(O(\sqrt n)\)

方法二中二进制拆分,每次插入/重构,像二进制进位一样,优化到了均摊\(O(logn)\)

太妙了太妙了,建议反复观看

posted @ 2024-10-04 18:35  lichenyu_ac  阅读(19)  评论(0)    收藏  举报