【已整理】字符串 AC自动机 总

1.hdu2222

给出若干个模式串,给出一个主串,求模式串在主串中有哪些出现了,输出出现的模式串的个数

#include<stdio.h>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
const int MAXN=1000005;
const int MAX_NODE=500005;
const int SIGMA_SIZE=26;

int idx(char c){
    return c-'a';
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(char *s){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]++;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            if(!next[0][i])
                next[0][i]=0;
            else{
                fail[next[0][i]]=0;
                que.push(next[0][i]);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[next[u][i]]=y;
                    que.push(x);
                }
            }
        }
    }
    int match(char *s){
        int res=0;
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            now=next[now][c];
            int cur=now;
            while(cur!=0){
                res+=val[cur];
                val[cur]=0;
                cur=fail[cur];
            }
        }
        return res;
    }
}aho;

int t,n;
char s[MAXN];

int main(){
    scanf("%d",&t);
    while(t--){
        aho.init();
        scanf("%d",&n);
        for(int i=0;i<n;i++){
            scanf("%s",s);
            aho.insert(s);
        }
        aho.build();
        scanf("%s",s);
        printf("%d\n",aho.match(s));
    }
    return 0;
}
View Code

模板更新,做了一些优化

2.hdu2896

这题MLE到怀疑人生。。。结果发现了一个神奇的东西,数组如果一开始开了很大,但是程序运行过程中并没有访问到那么多,没有访问到的是不会算到Memory中的,所以是修改了初始化函数后才AC的

#include<stdio.h>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
typedef long long ll;
const int MAXN=10005;
const int MAX_NODE=500*205;
const int SIGMA_SIZE=128;

int idx(char c){
    return c;
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(char *s,int pp){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]=pp;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            if(!next[0][i])
                next[0][i]=0;
            else{
                fail[next[0][i]]=0;
                que.push(next[0][i]);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[next[u][i]]=y;
                    que.push(x);
                }
            }
        }
    }
    set<int> ss;
    bool match(char *s,int web){
        ss.clear();
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            now=next[now][c];
            int cur=now;
            while(cur!=0){
                if(val[cur])
                    ss.insert(val[cur]);
                cur=fail[cur];
            }
            if(ss.size()>=3)
                break;
        }
        if(ss.size()){
            printf("web %d:",web);
            for(set<int>::iterator it=ss.begin();it!=ss.end();it++)
                printf(" %d",*it);
            printf("\n");
            return true;
        }
        return false;
    }
}aho;

int t,n;
char s[MAXN];

int main(){
    int n,m;
    while(~scanf("%d",&n)){
        aho.init();
        for(int i=1;i<=n;i++){
            scanf("%s",s);
            aho.insert(s,i);
        }
        aho.build();
        scanf("%d",&m);
        int tot=0;
        for(int i=1;i<=m;i++){
            scanf("%s",s);
            if(aho.match(s,i))
                tot++;
        }
        printf("total: %d\n",tot);
    }
    return 0;
}
View Code

3.hdu3065

#include<stdio.h>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
typedef long long ll;
const int MAXN=1005;
const int MAX_NODE=50005;
const int SIGMA_SIZE=26;
int counter[MAXN];

int idx(char c){
    return c-'A';
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(char *s,int pp){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]=pp;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            if(!next[0][i])
                next[0][i]=0;
            else{
                fail[next[0][i]]=0;
                que.push(next[0][i]);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[x]=y;
                    que.push(x);
                }
            }
        }
    }
    void match(char *s){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            if(s[i]<'A'||s[i]>'Z'){
                now=0;
                continue;
            }
            int c=idx(s[i]);
            now=next[now][c];
            int cur=now;
            while(cur!=0){
                if(val[cur])
                    counter[val[cur]]++;
                cur=fail[cur];
            }
        }
    }
}aho;

int t,n;
char s[1005][55];
char str[2000005];
int main(){
    int n,m;
    while(~scanf("%d",&n)){
        aho.init();
        memset(counter,0,sizeof(counter));
        for(int i=1;i<=n;i++){
            scanf("%s",s[i]);
            aho.insert(s[i],i);
        }
        aho.build();
        scanf("%s",str);
        aho.match(str);
        for(int i=1;i<=n;i++){
            if(counter[i])
                printf("%s: %d\n",s[i],counter[i]);
        }
    }
    return 0;
}
View Code

这题有一个很重要的注意点,给出的主串里面会包含模式串中不含有的字母,即不在SIGMA_SIZE中的字母,此时需要特判一下, 然后直接回到根节点。

前三题都是AC自动机最裸的题型

4.codeforces 86C Genetic engineering

给出m个模式串,再给出一个n,让你构造长度为n的串,使得这个串中的每一个字母都至少包含在某一个模式串中。即对于每一个位置的字母,都能找到一个区间,这个字母在这个区间中,且这个区间形成的串是某一模式串。构造出来的串中,不需要全部模式串都出现。n的范围1e3。且其它的数据量也都较小,很明显可以用DP做,一开始想的是DP[i][j]表示当前构造出了长度为i的串,终止在j号节点的构造方案数。但是很明显会重复- -(我是调试了很久才发现这个问题。。。)

网上搜了别人的题解,都是这么开的状态。dp[i][j][k]表示当前需要构造出长度为i的串,到了j号节点,且还有长度为k的后缀没有匹配。

leftover[x]表示在Trie树中,匹配到x号节点最长能匹配多长的后缀。我们用dp[i][j][k]去更新后面的状态,如果j号节点的下一个节点为x,且leftover[x]》=k+1,则累加更新dp[i][x][0],否则,累加更新dp[i][x][k+1]

5.poj2778 DNA Sequence 

AC自动机的一种经典题型

题意:给你若干个模式串,让你构造长度为n的不包含这些模式串的串,问你能够造出多少个。很久之前在碰到这种类型的题目的时候,我就觉得这很像形式语言与自动机里学过的,应该可以根据模式串建立的Trie树构造出递推式,然后就可以矩阵加速了。具体方法学一下吧。

如上图,我们构建出整个AC自动机(模式串为ACG和C)。我们先求出一个矩阵,这个矩阵的aij表示从i号节点到j号节点有多少种走法。那么我们需要考虑的有哪些节点?显然只有图中的白色节点,因为红色节点已经包含了模式串,所以不符合条件。3和4是显然的,为什么2也是呢?因为2的fail指针指向了4,意思就是2的后缀中包含了4号节点代表的串,所以2号节点的串也不符合条件,所以本题中我们只需要考虑0和1号节点,构建出一个二维的一步转移矩阵,意思是长度为1的符合条件的串有多少种,然后n次方就表示长度为n的符合条件的串有多少种,因为一开始是从0号节点开始的,所以最终的答案就是A^n中的a[0][0]+a[0][1],即从0号节点n步转移到所有合法节点的方案数的和。

#include<stdio.h>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
typedef long long ll;
const int MAXN=1005;
const int MAX_NODE=105;
const int SIGMA_SIZE=4;

int idx(char c){
    if(c=='A')  return 0;
    if(c=='T')  return 1;
    if(c=='G')  return 2;
    return 3;
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int size;
    queue<int> que;
    struct Matrix{
        ll mat[105][105];
        Matrix(int flag=0){
            memset(mat,0,sizeof(mat));
            if(flag==1){
                for(int i=0;i<105;i++)
                    mat[i][i]=1;
            }
        }
    };
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(char *s){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]++;
    }
    Matrix mul(Matrix a,Matrix b,ll mod){
        Matrix ret(0);
        for(int i=0;i<size;i++){
            for(int j=0;j<size;j++){
                ret.mat[i][j]=0;
                for(int k=0;k<size;k++){
                    ret.mat[i][j]+=a.mat[i][k]*b.mat[k][j];
                    ret.mat[i][j]%=mod;
                }
            }
        }
        return ret;
    }
    Matrix pow_M(Matrix a,ll n,ll mod){
        Matrix ret(1);
        Matrix temp=a;
        while(n){
            if(n&1)
                ret=mul(ret,temp,mod);
            temp=mul(temp,temp,mod);
            n>>=1;
        }
        return ret;
    }
    void build(ll n){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            int &x=next[0][i];
            if(!x)
                x=0;
            else{
                fail[x]=0;
                que.push(x);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            if(val[fail[u]])
                val[u]=1;
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[x]=y;
                    que.push(x);
                }
            }
        }
        Matrix a(0);
        for(int i=0;i<size;i++){
            if(val[i])  continue;
            for(int j=0;j<4;j++){
                int v=next[i][j];
                if(!val[v])
                    a.mat[i][v]++;
            }
        }
        ll mod=100000;
        a=pow_M(a,n,mod);
        ll res=0;
        for(int i=0;i<size;i++){
            if(val[i])  continue;
            res=(res+a.mat[0][i])%mod;
        }
        printf("%lld\n",res);
    }
}aho;

char s[25];

int main(){
    int m;
    ll n;
    while(~scanf("%d%lld",&m,&n)){
        aho.init();
        for(int i=1;i<=m;i++){
            scanf("%s",s);
            aho.insert(s);
        }
        aho.build(n);
    }
    return 0;
}
View Code

6.hdu2243 考研路茫茫――单词情结 重要

题意:和上题类似,本题略有加强,让你求包含任意一个模式串的长度不超过n的串有多少种。为了应用上述的方法,我们将问题转化为用总的方案数减去不包含的方案数。但是问题是,我们需要求长度不超过n的串,意思就是A^1+A^2+……+A^n=B,在求B的第0行的和,就是不包含的方案数。这个怎么求?引入一个很重要的公式。

怎么理解?第二列的E E可以求出前一个矩阵第一行的和,所以每乘上一次该矩阵,左上角可以得到A^n,右上角可以得到之前所有矩阵的和。

理解这个东西每次只能乘一次该矩阵,理解之后我们只要用矩阵快速幂计算即可。

我们求出最终的矩阵后,求出第0行的和,不一定就是最终的答案,因为单位矩阵表示的是长度为0的串,所以要根据题意适当地减去1。

其实本题没必要在后面补两个单位矩阵,因为只需要求第0行的和,所以我们只需要在最后加上一列1即可,假设A是n*n的,我们需要将矩阵变成(n+1)*(n+1)的,最后一列全部补1,剩下的最后一行的位置全部补0即可。

本题还有分治的方法,这种求前缀和很好分治,我们将其对分成两半,后一部分提出公共的幂次,就变成了两个相同的子问题,就这样递归分治即可。

#include<stdio.h>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
typedef unsigned long long ll;
const int MAXN=1005;
const int MAX_NODE=55;
const int SIGMA_SIZE=26;

int idx(char c){
    return c-'a';
}

ll quick(ll n,ll k){
    ll res=1;
    while(k){
        if(k&1)
            res*=n;
        n*=n;
        k>>=1;
    }
    return res;
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int size;
    queue<int> que;
    struct Matrix{
        ll mat[MAX_NODE][MAX_NODE];
        Matrix(int flag=0){
            memset(mat,0,sizeof(mat));
            if(flag==1){
                for(int i=0;i<MAX_NODE;i++)
                    mat[i][i]=1;
            }
        }
    };
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(char *s){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]=1;
    }
    Matrix mul(Matrix a,Matrix b,int bound){
        Matrix ret(0);
        for(int i=0;i<bound;i++){
            for(int j=0;j<bound;j++){
                ret.mat[i][j]=0;
                for(int k=0;k<bound;k++){
                    ret.mat[i][j]+=a.mat[i][k]*b.mat[k][j];
                }
            }
        }
        return ret;
    }
    Matrix pow_M(Matrix a,ll n,int bound){
        Matrix ret(1);
        Matrix temp=a;
        while(n){
            if(n&1)
                ret=mul(ret,temp,bound);
            temp=mul(temp,temp,bound);
            n>>=1;
        }
        return ret;
    }
    void build(ll n){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            int &x=next[0][i];
            if(!x)
                x=0;
            else{
                fail[x]=0;
                que.push(x);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            if(val[fail[u]])
                val[u]=1;
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[x]=y;
                    que.push(x);
                }
            }
        }
        Matrix a(0);
        for(int i=0;i<size;i++){
            if(val[i])  continue;
            for(int j=0;j<SIGMA_SIZE;j++){
                int v=next[i][j];
                if(!val[v])
                    a.mat[i][v]++;
            }
        }
        val[size]=0;//很重要,因为size并不是Trie中的节点,所以之前的样例可能使得val[size]为1,需要将其初始化为0
        //另外如果我们一开始就在init里面memset初始化val也行
        //这题要不是样例刚好出到,我可能一万年都找不到这个BUG
        for(int i=0;i<size+1;i++)
            a.mat[i][size]=1;
        a=pow_M(a,n,size+1);
        ll res=0;
        for(int i=0;i<size+1;i++){
            if(val[i])  continue;
            res+=a.mat[0][i];
        }
        res--;
        Matrix b(0);
        b.mat[0][0]=26;b.mat[0][1]=26;
        b.mat[1][0]=0;b.mat[1][1]=1;
        b=pow_M(b,n-1,2);
        ll tt=0;
        tt+=26*b.mat[0][0]+b.mat[0][1];
        printf("%llu\n",tt-res);
    }
}aho;

char s[25];

int main(){
    int m;
    ll n;
    while(~scanf("%d%llu",&m,&n)){
        aho.init();
        for(int i=1;i<=m;i++){
            scanf("%s",s);
            aho.insert(s);
        }
        aho.build(n);
    }
    return 0;
}
View Code

 7.hdu2825 AC自动机状压DP基础题

题意:给你m个模式串,让你构造长度为n的至少包含k个模式串的串,问你有多少个这样的串

m范围只有10,所以考虑状态压缩

dp[i][j][S]表示长度为i,到了j号节点,包含的串为S的串的个数

则转移方法很简单很暴力

for(int i=0;i<=n;i++){
    for(int j=0;j<size;j++){
        for(int S=0;S<(1<<m);S++)
            dp[i][j][S]=0;
    }
}
dp[0][0][0]=1;
for(int i=0;i<n;i++){
    for(int j=0;j<size;j++){
        for(int S=0;S<(1<<m);S++){
            if(dp[i][j][S]){
                for(int z=0;z<SIGMA_SIZE;z++){
                    int newi=i+1;
                    int newj=next[j][z];
                    int newS=S|val[newj];
                    dp[newi][newj][newS]+=dp[i][j][S];
                    dp[newi][newj][newS]%=mod;
                }
            }
        }
    }
}
View Code

对于val数组,有个很好的技巧,一开始我打算在val里面存当前节点代表的模式串是第几个模式串,为了dp的时候方便,其实我们可以直接将val数组保存为当前节点代表的模式串的权值,即(1<<i),i表示当前节点代表的模式串是第i个模式串。

在aho.build()的时候,要注意val数组的继承关系

while(!que.empty()){
    int u=que.front();
    que.pop();
    val[u]|=val[fail[u]];
    for(int i=0;i<SIGMA_SIZE;i++){
        int &x=next[u][i],y=next[fail[u]][i];
        if(!x)
            x=y;
        else{
            fail[x]=y;
            que.push(x);
        }
    }
}
View Code
val[u]|=val[fail[u]];

这个式子表示,如果我们匹配到当前节点,则从该节点开始一直走失配节点,走到的所有节点所代表的字符串都要加到集合里面

对于一个集合,我们预处理出其中包含的字符串个数,方便统计

for(int i=0;i<(1<<10);i++){
    num[i]=0;
    for(int j=0;j<10;j++){
        if(i&(1<<j))
            num[i]++;
    }
}

最后统计的答案是,对于所有字符串个数大于等于k的集合,我们统计AC自动机上每个节点的贡献

int res=0;
for(int S=0;S<(1<<m);S++){
    if(num[S]<k)    continue;
    for(int i=0;i<size;i++)
        res=(res+dp[n][i][S])%mod;
}
printf("%d\n",res);
#include<stdio.h>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
const int mod=20090717;
const int MAXN=1005;
const int MAX_NODE=105;
const int SIGMA_SIZE=26;
int num[1025];

int idx(char c){
    return c-'a';
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int dp[26][102][1025];
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(char *s,int i){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]=i;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            int &x=next[0][i];
            if(!x)
                x=0;
            else{
                fail[x]=0;
                que.push(x);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            val[u]|=val[fail[u]];
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[x]=y;
                    que.push(x);
                }
            }
        }
    }
    void dpmaker(int n,int m,int k){
        for(int i=0;i<=n;i++){
            for(int j=0;j<size;j++){
                for(int S=0;S<(1<<m);S++)
                    dp[i][j][S]=0;
            }
        }
        dp[0][0][0]=1;
        for(int i=0;i<n;i++){
            for(int j=0;j<size;j++){
                for(int S=0;S<(1<<m);S++){
                    if(dp[i][j][S]){
                        for(int z=0;z<SIGMA_SIZE;z++){
                            int newi=i+1;
                            int newj=next[j][z];
                            int newS=S|val[newj];
                            dp[newi][newj][newS]+=dp[i][j][S];
                            dp[newi][newj][newS]%=mod;
                        }
                    }
                }
            }
        }
        int res=0;
        for(int S=0;S<(1<<m);S++){
            if(num[S]<k)    continue;
            for(int i=0;i<size;i++)
                res=(res+dp[n][i][S])%mod;
        }
        printf("%d\n",res);
    }
}aho;

char s[15];

int main(){
    for(int i=0;i<(1<<10);i++){
        num[i]=0;
        for(int j=0;j<10;j++){
            if(i&(1<<j))
                num[i]++;
        }
    }
    int n,m,k;
    while(scanf("%d%d%d",&n,&m,&k)){
        if(n==0&&m==0&&k==0)
            break;
        aho.init();
        for(int i=0;i<m;i++){
            scanf("%s",s);
            aho.insert(s,(1<<i));
        }
        aho.build();
        aho.dpmaker(n,m,k);
    }
    return 0;
}
View Code

 8.hdu2296 Ring

给你一些模式串,每个模式串都有一个分值,让你构造一个长度不超过n的串,使其分值最大,一个串的分值等于这个串中出现的模式串的分数之和,重复出现的模式串(可以重叠)的分数也要重复计算

很明显的AC自动机+DP,我们先利用val数组把所有点的分数预处理出来,这样就不用在DP时走失配指针了。

dp[i][j]表示长度为i且走到节点j时的最大分数,转移方式很简单,采用我为人人转移方式,newi=i+1,枚举k从0到SIGMA_SIZE,newj=next[j][k],如果dp[i][j]+val[newj]较大,则更新dp[newi][newj]。

本题对于最优解的条件较多,按照优先顺序如下

1.分数最大

2.长度最短

3.字典序最小

由于串的长度较小,所以可以对于每一个DP状态都记录下来,然后暴力比较就行了。

#include<stdio.h>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
const int MAX_NODE=1005;
const int SIGMA_SIZE=26;

int idx(char c){
    return c-'a';
}

char str[55];

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int dp[55][MAX_NODE];
    char rec[55][MAX_NODE][55];
    int tot[55][MAX_NODE];
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(char *s,int i){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]=i;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            int &x=next[0][i];
            if(!x)
                x=0;
            else{
                fail[x]=0;
                que.push(x);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            val[u]+=val[fail[u]];
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[x]=y;
                    que.push(x);
                }
            }
        }
    }
    void dpmaker(int n){
        for(int i=0;i<=n;i++){
            for(int j=0;j<size;j++){
                dp[i][j]=-1;
                tot[i][j]=0;
                rec[i][j][0]='\0';
            }
        }
        dp[0][0]=0;
        for(int i=0;i<n;i++){
            for(int j=0;j<size;j++){
                if(dp[i][j]!=-1){
                    for(int k=0;k<SIGMA_SIZE;k++){
                        int newi=i+1;
                        int newj=next[j][k];
                        if(dp[newi][newj]<dp[i][j]+val[newj]){
                            dp[newi][newj]=dp[i][j]+val[newj];
                            strcpy(rec[newi][newj],rec[i][j]);
                            int &tt=tot[newi][newj];
                            tt=tot[i][j];
                            rec[newi][newj][tt++]='a'+k;
                            rec[newi][newj][tt]='\0';
                        }
                        else if(dp[newi][newj]==dp[i][j]+val[newj]){
                            strcpy(str,rec[i][j]);
                            int tt=tot[i][j];
                            str[tt++]='a'+k;
                            str[tt]='\0';
                            if(strcmp(str,rec[newi][newj])<0){
                                strcpy(rec[newi][newj],str);
                                tot[newi][newj]=tt;
                            }
                        }
                    }
                }
            }
        }
        int ii=0,jj=0;
        for(int i=1;i<=n;i++){
            for(int j=0;j<size;j++){
                if(dp[ii][jj]<dp[i][j]){
                    ii=i;
                    jj=j;
                }
                else if(dp[ii][jj]==dp[i][j]&&tot[ii][jj]>tot[i][j]){
                    ii=i;
                    jj=j;
                }
                else if(dp[ii][jj]==dp[i][j]&&tot[ii][jj]==tot[i][j]&&strcmp(rec[ii][jj],rec[i][j])>0){
                    ii=i;
                    jj=j;
                }
            }
        }
        printf("%s\n",rec[ii][jj]);
    }
}aho;

char s[105][15];
int num[105];

int main(){
    int t,n,m;
    scanf("%d",&t);
    for(int cas=1;cas<=t;cas++){
        scanf("%d%d",&n,&m);
        aho.init();
        for(int i=0;i<m;i++)
            scanf("%s",s[i]);
        for(int i=0;i<m;i++)
            scanf("%d",num+i);
        for(int i=0;i<m;i++)
            aho.insert(s[i],num[i]);
        aho.build();
        aho.dpmaker(n);
    }
    return 0;
}
View Code

 9.hdu5880

题意:给出若干个模式串,再给一个主串,让你把主串中所有出现的模式串变为*

很明显的AC自动机,但是真的碰到了很多问题。。。

首先是输入的问题。我开的字符数组,读一行的话用cin.getline(s,MAXN);也可以用gets(s)。前者稍微快一点

其次是输出的问题。一个字符一个字符输出的话不要用printf,用putchar

其次是如何处理匹配的问题,首先为了不漏掉任何一个匹配,也为了避免的重复的问题,我们先在构建fail指针时处理出每个点所代表的串的后缀中最长的模式串的长度。这样在用主串匹配时,就不用走fail指针了。其次,我们最好不要一边匹配一边输出,这样会出现问题。。。其实我至今不知道自己的代码哪里出问题了。。。下面先放出出问题的代码

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;
const int MAXN=1000005;
const int MAX_NODE=1000005;
const int SIGMA_SIZE=26;
int tmp[MAXN];

int idx(char c){
    if(c>='a'&&c<='z')
        return c-'a';
    if(c>='A'&&c<='Z')
        return c-'A';
    return -1;
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(char *s){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]=n;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            int &x=next[0][i];
            if(!x)
                x=0;
            else{
                fail[x]=0;
                que.push(x);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            val[u]=max(val[u],val[fail[u]]);
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[x]=y;
                    que.push(x);
                }
            }
        }
    }
    void solve(char *s){
        memset(tmp,0,sizeof(tmp));
        int cur=0,n=strlen(s),now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(c==-1){
                now=0;
                continue;
            }
            now=next[now][c];
            if(val[now]){
                for(int j=cur;j<=i-val[now];j++)
                    putchar(s[j]);
                for(int j=max(cur,i-val[now]+1);j<=i;j++)
                    putchar('*');
                cur=i+1;
            }
        }
        for(int j=cur;j<n;j++)
            putchar(s[j]);
        putchar('\n');
    }
}aho;

char s[MAXN];

int main(){
    int t,n;
    scanf("%d",&t);
    for(int cas=1;cas<=t;cas++){
        aho.init();
        scanf("%d",&n);
        for(int i=0;i<n;i++){
            scanf("%s",s);
            aho.insert(s);
        }
        aho.build();
        getchar();
        gets(s);
        aho.solve(s);
    }
    return 0;
}
View Code

如果不在线处理,那我们肯定要将最长匹配所在的区间标记一下,然后最后再根据标记的区间去输出。有一个很巧妙很重要的方法。

比如我们要标记[a,b]区间,那最暴力的方法肯定是这个区间的每个值设为1,如果觉得太暴力了,那就用线段树,其实没有什么很大的改进。

一个巧妙的办法是,cnt[a]--,cnt[b+1]++,那我们一边处理一边求前缀和,如果当前点的前缀和小于0,说明被标记了,大于等于0,说明没有被标记。

如果标记的区间可能有重叠,该方法也是适用的。

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;
const int MAXN=1000005;
const int MAX_NODE=1000005;
const int SIGMA_SIZE=26;
int tmp[MAXN];

int idx(char c){
    if(c>='a'&&c<='z')
        return c-'a';
    if(c>='A'&&c<='Z')
        return c-'A';
    return -1;
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(char *s){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]=n;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            int &x=next[0][i];
            if(!x)
                x=0;
            else{
                fail[x]=0;
                que.push(x);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            val[u]=max(val[u],val[fail[u]]);
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[x]=y;
                    que.push(x);
                }
            }
        }
    }
    void solve(char *s){
        memset(tmp,0,sizeof(tmp));
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(c==-1){
                now=0;
                continue;
            }
            now=next[now][c];
            if(val[now]){
                tmp[i+1]++;
                tmp[i-val[now]+1]--;
            }
        }
        int cnt=0;
        for(int i=0;i<n;i++){
            cnt+=tmp[i];
            if(cnt<0)
                putchar('*');
            else
                putchar(s[i]);
        }
        putchar('\n');
    }
}aho;

char s[MAXN];

int main(){
    int t,n;
    scanf("%d",&t);
    for(int cas=1;cas<=t;cas++){
        aho.init();
        scanf("%d",&n);
        for(int i=0;i<n;i++){
            scanf("%s",s);
            aho.insert(s);
        }
        aho.build();
        getchar();
        cin.getline(s,MAXN);
        aho.solve(s);
    }
    return 0;
}
View Code

 10.zoj3228

题意:给出一个主串,再给出一些询问,每次询问给出一个模式串,问你这个模式串在主串中可重叠的话出现了多少次,不可重叠的话最多出现了多少次

不可重叠的最大匹配数就是从左往右匹配即可

可重叠的匹配数就是裸的AC自动机,不可重叠的匹配数其实也是裸的,只是我们需要记录一个东西,当前发现匹配的这个串上一次出现的位置,如果判断出来是重叠的,就不去管它,如果不重叠,就贡献加一,且更新最后一次出现的位置。

这题让我发现了以前的AC自动机的板子都是错的。。。找匹配数的代码要像本题一样写

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;
const int MAXN=100005;
const int MAX_NODE=700005;
const int SIGMA_SIZE=26;
int cnta[MAX_NODE],cntb[MAX_NODE],last[MAX_NODE];
int belong[MAXN];

int idx(char c){
    return c-'a';
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE],depth[MAX_NODE];
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=depth[0]=0;
        size=1;
    }
    int insert(char *s){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                depth[size]=i+1;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]=1;
        return now;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            int &x=next[0][i];
            if(!x)
                x=0;
            else{
                fail[x]=0;
                que.push(x);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[x]=y;
                    que.push(x);
                }
            }
        }
    }
    void solve(char *s){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            now=next[now][c];
            int cur=now;
            while(cur!=0){
                if(val[cur]){
                    cnta[cur]++;
                    if(i-depth[cur]+1>last[cur]){
                        cntb[cur]++;
                        last[cur]=i;
                    }
                }
                cur=fail[cur];
            }
        }
    }
}aho;

char s[MAXN];
char t[MAXN][7];
int op[MAXN];

int main(){
    int n;
    int cas=1;
    while(~scanf("%s",s)){
        aho.init();
        memset(cnta,0,sizeof(cnta));
        memset(cntb,0,sizeof(cntb));
        memset(last,-1,sizeof(last));
        scanf("%d",&n);
        for(int i=0;i<n;i++){
            scanf("%d%s",op+i,t[i]);
            belong[i]=aho.insert(t[i]);
        }
        aho.build();
        aho.solve(s);
        printf("Case %d\n",cas++);
        for(int i=0;i<n;i++){
            if(op[i]==0)
                printf("%d\n",cnta[belong[i]]);
            else
                printf("%d\n",cntb[belong[i]]);
        }
        printf("\n");
    }
    return 0;
}
View Code

 11.hdu2457

题意:串只包含A,T,G,C。给你若干个模式串,一个原串,改变原串若干次,使其不包含任何一个模式串,问你最少改变多少次。

AC自动机+DP。

dp[i][j]表示长度为i,走到节点j的最少改变次数。

本题其实也是一道构造题。和非法串的构造题很像,我们先要处理出危险的节点,这些节点一次都不能到,所以我们dp的时候直接避开这些点。

转移的时候,如果当前边和原串的当前字母相同,则不需要改变,否则改变次数+1。

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;
const int MAXN=1005;
const int MAX_NODE=1005;
const int SIGMA_SIZE=4;
const int INF=1e9+7;

int idx(char c){
    if(c=='A')  return 0;
    if(c=='T')  return 1;
    if(c=='G')  return 2;
    return 3;
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int dp[MAXN][MAX_NODE];
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(char *s){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]=1;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            int &x=next[0][i];
            if(!x)
                x=0;
            else{
                fail[x]=0;
                que.push(x);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            if(val[fail[u]])
                val[u]=1;
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[x]=y;
                    que.push(x);
                }
            }
        }
    }
    int solve(char *s){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<=n;i++){
            for(int j=0;j<size;j++)
                dp[i][j]=INF;
        }
        dp[0][0]=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            for(int j=0;j<size;j++){
                if(val[j])  continue;
                if(dp[i][j]!=INF){
                    for(int k=0;k<SIGMA_SIZE;k++){
                        int newi=i+1;
                        int newj=next[j][k];
                        if(val[newj])   continue;
                        if(k==c)
                            dp[newi][newj]=min(dp[newi][newj],dp[i][j]);
                        else
                            dp[newi][newj]=min(dp[newi][newj],dp[i][j]+1);
                    }
                }
            }
        }
        int res=INF;
        for(int i=0;i<size;i++){
            if(!val[i])
                res=min(res,dp[n][i]);
        }
        if(res==INF)
            res=-1;
        return res;
    }
}aho;

char s[MAXN];

int main(){
    int n;
    int cas=1;
    while(scanf("%d",&n)){
        if(n==0)
            break;
        aho.init();
        for(int i=0;i<n;i++){
            scanf("%s",s);
            aho.insert(s);
        }
        aho.build();
        scanf("%s",s);
        printf("Case %d: %d\n",cas++,aho.solve(s));
    }
    return 0;
}
View Code

 12.hdu6096

题意:给出若干个串,再给出若干个询问,每次给出一个前缀和一个后缀,问你这对前缀和后缀可以和之前给出的多少个串匹配,要求前缀和后缀不重叠

题解:看了别人的题解才会,确实不是很难,这种套路要吃一堑长一智。因为要同时匹配前缀和后缀,所以我们将一开始给出的串都赋值一份贴在后面,且中间要加上一个无关符号‘#’。然后对于每个询问,我们离线处理,对于每队前缀和后缀,我们构造出后缀#前缀这样的串,然后加到AC自动机里,然后用原来的串去匹配就可以了。本题所给的空间非常多,所以可以猜想是用AC自动机、字典树一类的代码。

然后遇到了一些问题。

1.要求前缀和后缀不重叠,只要比较匹配位置的深度和当前匹配的串的长度即可

2.增加了'#'后,idx函数要增加相应语句

3.为什么要增加'#'?不能直接拼在后面吗?这个“#”真的很有用,一方面,避免了非前缀+后缀的匹配,另一方面,也不会重复匹配,因为‘#’只出现一次,也就决定了最多只会匹配一次

4.如何统计答案?一开始我将AC自动机里的val数组设置为了对应的询问,殊不知可能会有若干个询问是相同的前缀和后缀的情况,这样后一个插入的就会覆盖前一个插入的,所以更好的办法时,插入后返回串的节点位置,然后记录每个询问所在节点。每次匹配,对AC自动机里的节点贡献加1,最后输出答案的时候,利用之前记录下的位置,去访问对应位置的统计值即可

5.scanf读string很麻烦,但是题目又有很多串,而总长度又不会超过500000,我们不可能一开始开那么多的字符数组,所以我们只需要开一个字符数组,然后一个一个往后读,记录下每个串的起始位置即可

6.增加了'#',AC自动机的节点数不再是500000,而应该再增加串的个数,所以开成600000比较合适

View Code

 13.hdu6138

题意:给你若干个串,每次询问指定两个串,问你这两个串的最长的公共的且是某一个串的前缀的子串。也是很明显可以用AC自动机做的。

我先把所有的串加入AC自动机,然后每次询问,我分别走一趟这两个串,然后分别给匹配的节点打上标记,因为是前缀,所以只要是经过的点都要打标记,还有很重要的一点,所有点网上走fail指针直到根节点的所有点也都要打上标记,最后只要查一下AC自动机中的所有点,两次都有标记的点就可以用深度去更新答案,答案是最深的那个两次都有标记的节点的深度。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include <map>
#include <set>
#include <queue>
using namespace std;
typedef long long ll;
const int MAXN=200010;
const int MAX_NODE=100050;
const int SIGMA_SIZE=26;

char s[MAXN];
int loc[MAXN];
int tot;

int idx(char c){
    return c-'a';
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int depth[MAX_NODE];
    bool vis[2][MAX_NODE];
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        val[0]=depth[0]=0;
        size=1;
    }
    void insert(char *s){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            int &x=next[now][c];
            if(!x){
                memset(next[size],0,sizeof(next[size]));
                val[size]=0;depth[size]=i+1;
                x=size++;
            }
            now=x;
        }
        val[now]=1;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            int &x=next[0][i];
            if(!x)
                x=0;
            else{
                fail[x]=0;
                que.push(x);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[x]=y;
                    que.push(x);
                }
            }
        }
    }
    int solve(int n,int m){
        memset(vis[0],false,sizeof(vis[0]));
        memset(vis[1],false,sizeof(vis[1]));
        int lenn=strlen(s+loc[n]),lenm=strlen(s+loc[m]);
        int now=0;
        for(int i=loc[n];i<loc[n]+lenn;i++){
            int c=idx(s[i]);
            now=next[now][c];
            int cur=now;
            while(cur){
                vis[0][cur]=true;
                cur=fail[cur];
            }
        }
        now=0;
        for(int i=loc[m];i<loc[m]+lenm;i++){
            int c=idx(s[i]);
            now=next[now][c];
            int cur=now;
            while(cur){
                vis[1][cur]=true;
                cur=fail[cur];
            }
        }
        int res=0;
        for(int i=0;i<size;i++){
            if(vis[0][i]&&vis[1][i])
                res=max(res,depth[i]);
        }
        return res;
    }
}aho;

int main(){
    int t,n,m,p,q;
    scanf("%d",&t);
    for(int cas=1;cas<=t;cas++){
        aho.init();
        scanf("%d",&n);
        tot=0;
        for(int i=1;i<=n;i++){
            scanf("%s",s+tot);
            loc[i]=tot;
            aho.insert(s+tot);
            tot+=strlen(s+tot)+1;
        }
        aho.build();
        scanf("%d",&m);
        for(int i=0;i<m;i++){
            scanf("%d%d",&p,&q);
            printf("%d\n",aho.solve(p,q));
        }
    }
    return 0;
}
View Code

 14.hdu3341

题意:给你若干个模式串,再给你一个主串,然后让你重排主串,使得重排后包含尽可能多的模式串,重复包含同一个模式串也要重复计算。

肯定是要用DP写,但是不知道怎么设置状态。。。

由于给出的主串长度只有40,且字母只有A、C、T、G,所以考虑状态压缩,由于我们只在乎每个字母出现了多少次,所以我们先统计每个字母出现了多少次,加入A、C、T、G分别出现了A、C、T、G次,那么所有的满足

的a、b、c、d都可以组成一个状态,哈希一下,再反向记录一下就行了。

知道了状态表示法后,DP的部分就非常好写了

一开始对于状态数的估计出现了错误,一共四十个字母,那么状态数最多10*10*10*10,其实应该是11*11*11*11(mdzz。。。)

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include <map>
#include <set>
#include <queue>
using namespace std;
typedef long long ll;
const int MAXN=200010;
const int MAX_NODE=505;
const int SIGMA_SIZE=4;
int num[4];
int hasher[41][41][41][41];
int dp[MAX_NODE][15555];
int tot;
struct Node{
    int num[4];
}node[15555];

int idx(char c){
    if(c=='A')    return 0;
    if(c=='T')    return 1;
    if(c=='G')    return 2;
    return 3;
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        val[0]=0;
        size=1;
    }
    void insert(char *s){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            int &x=next[now][c];
            if(!x){
                memset(next[size],0,sizeof(next[size]));
                val[size]=0;
                x=size++;
            }
            now=x;
        }
        val[now]++;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            int &x=next[0][i];
            if(!x)
                x=0;
            else{
                fail[x]=0;
                que.push(x);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            val[u]+=val[fail[u]];
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[x]=y;
                    que.push(x);
                }
            }
        }
    }
    void solve(){
        for(int i=0;i<size;i++){
            for(int j=0;j<tot;j++)
                dp[i][j]=-1;
        }//2e7
        dp[0][tot-1]=0;
        for(int S=tot-1;S>=1;S--){
            for(int i=0;i<size;i++){
                if(dp[i][S]==-1)    continue;
                for(int j=0;j<SIGMA_SIZE;j++){
                    if(node[S].num[j]){
                        int newi=next[i][j];
                        int w=node[S].num[0];
                        int x=node[S].num[1];
                        int y=node[S].num[2];
                        int z=node[S].num[3];
                        if(j==0)    w--;
                        else if(j==1)    x--;
                        else if(j==2)    y--;
                        else    z--;
                        int newS=hasher[w][x][y][z];
                        dp[newi][newS]=max(dp[newi][newS],dp[i][S]+val[newi]);
                    }
                }
            }
        }
    }
}aho;

char s[15];
char t[55];

int main(){
    int n,m,p,q,cas=1;
    while(scanf("%d",&n)){
        if(n==0)    break;
        aho.init();
        for(int i=0;i<n;i++){
            scanf("%s",s);
            aho.insert(s);
        }
        aho.build();
        scanf("%s",t);
        memset(num,0,sizeof(num));
        int len=strlen(t);
        for(int i=0;i<len;i++)
            num[idx(t[i])]++;
        tot=0;
        for(int i=0;i<=num[0];i++){
            for(int j=0;j<=num[1];j++){
                for(int k=0;k<=num[2];k++){
                    for(int h=0;h<=num[3];h++){
                        hasher[i][j][k][h]=tot;
                        node[tot].num[0]=i;
                        node[tot].num[1]=j;
                        node[tot].num[2]=k;
                        node[tot++].num[3]=h;
                    }
                }
            }
        }
        aho.solve();
        int res=0;
        for(int i=0;i<aho.size;i++){
            res=max(res,dp[i][0]);
        }
        printf("Case %d: %d\n",cas++,res);
    }
    return 0;
}
View Code

15.poj1625

题意:给你若干个模式串,一个字母表,让你构造长度为m的字符串,问你能够构造出多少个不包含任意模式串的串。

很明显的AC自动机,由于m很小,不需要矩阵加速,但是没有让你取MOD,所以得用大数,即将DP数组开成大数

有一个坑点,即对于idx的设置,字符数组要设置成unsigned char,我就是因为用了char,RE了。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
#include <queue>
#include <stack>
using namespace std;

struct BigInteger{
    int A[25];
    enum{MOD = 10000};
    BigInteger(){memset(A, 0, sizeof(A)); A[0]=1;}
    void set(int x){memset(A, 0, sizeof(A)); A[0]=1; A[1]=x;}
    void print(){
        printf("%d", A[A[0]]);
        for (int i=A[0]-1; i>0; i--){
            if (A[i]==0){printf("0000"); continue;}
            for (int k=10; k*A[i]<MOD; k*=10) printf("0");
            printf("%d", A[i]);
        }
        printf("\n");
    }
    int& operator [] (int p) {return A[p];}
    const int& operator [] (int p) const {return A[p];}
    BigInteger operator + (const BigInteger& B){
        BigInteger C;
        C[0]=max(A[0], B[0]);
        for (int i=1; i<=C[0]; i++)
            C[i]+=A[i]+B[i], C[i+1]+=C[i]/MOD, C[i]%=MOD;
        if (C[C[0]+1] > 0) C[0]++;
        return C;
    }
    BigInteger operator * (const BigInteger& B){
        BigInteger C;
        C[0]=A[0]+B[0];
        for (int i=1; i<=A[0]; i++)
            for (int j=1; j<=B[0]; j++){
                C[i+j-1]+=A[i]*B[j], C[i+j]+=C[i+j-1]/MOD, C[i+j-1]%=MOD;
            }
        if (C[C[0]] == 0) C[0]--;
        return C;
    }
};

const int MAX_NODE=105;
const int SIGMA_SIZE=55;
unsigned char alphabet[55];
int idx[256];

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int size;
    BigInteger dp[55][105];
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(unsigned char *s){
        int now=0;
        for(int i=0;s[i];i++){
            int c=idx[s[i]];
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]=1;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            if(!next[0][i])
                next[0][i]=0;
            else{
                fail[next[0][i]]=0;
                que.push(next[0][i]);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            if(val[fail[u]])
                val[u]=1;
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[next[u][i]]=y;
                    que.push(x);
                }
            }
        }
    }
    void solve(int n,int m){
        for(int i=0;i<=n;i++){
            for(int j=0;j<size;j++)
                dp[i][j].set(0);
        }
        dp[0][0].set(1);
        for(int i=0;i<n;i++){
            for(int j=0;j<size;j++){
                if(val[j])  continue;
                for(int k=0;k<m;k++){
                    int newj=next[j][k];
                    if(val[newj])   continue;
                    int newi=i+1;
                    dp[newi][newj]=dp[newi][newj]+dp[i][j];
                }
            }
        }
        BigInteger res;
        for(int i=0;i<size;i++)
            res=res+dp[n][i];
        res.print();
    }
}aho;

unsigned char s[25];

int main(){
    int n,m,p;
    while(~scanf("%d%d%d",&n,&m,&p)){
        aho.init();
        cin>>alphabet;
        for(int i=0;i<n;i++)
            idx[alphabet[i]]=i;
        for(int i=0;i<p;i++){
            scanf("%s",s);
            aho.insert(s);
        }
        aho.build();
        aho.solve(m,n);
    }
    return 0;
}
View Code

16.hdu3247

题意:这题的字符只有0和1。给出n个合法的串和m个非法的串,让你求能够包含所有的合法的串且不包含任何一个非法的串的最短长度

思路一:想到这个思路感觉自己马上要AC了。。。。。由于n只有10,我们对n进行状态压缩。dp[i][S]表示走到AC自动机的节点i,包含合法串的状态为S的最短长度。如果暴力更新,想了一下要好几层循环,且时间复杂度肯定爆了。。。所以我很机智地想到了用类似SPFA的更新方法来做这个DP(本来这题的题设也很像最短路)。写完之后,MLE了。。。。。。。。然后看了下自己开的空间,确实是MLE了,以后一定要检查一下。不然写完才发现太浪费时间了。而且这么更新,每次只能走1步,时间复杂度估计也是不够的。下面先放出这一版MLE的代码。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
#include <queue>
#include <set>
#include <ctime>
using namespace std;

const int MAX_NODE=605;
const int SIGMA_SIZE=2;
const int INF=1e9+7;

int idx(char c){
    return c-'0';
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int dp[MAX_NODE][1<<10];
    bool vis[MAX_NODE][1<<10];
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(char *s,int num){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]|=num;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            if(!next[0][i])
                next[0][i]=0;
            else{
                fail[next[0][i]]=0;
                que.push(next[0][i]);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            if(val[fail[u]]==INF)
                val[u]=INF;
            if(val[fail[u]]>=0)
                val[u]|=val[fail[u]];
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[next[u][i]]=y;
                    que.push(x);
                }
            }
        }
    }
    queue< pair<int,int> > Q;
    void solve(int n){
        for(int i=0;i<size;i++){
            for(int j=0;j<(1<<n);j++)
                dp[i][j]=INF,vis[i][j]=false;
        }
        dp[0][0]=0;
        vis[0][0]=true;
        Q.push(make_pair(0,0));
        while(!Q.empty()){
            int i=Q.front().first;
            int S=Q.front().second;
            vis[i][S]=false;
            Q.pop();
            for(int j=0;j<SIGMA_SIZE;j++){
                int newi=next[i][j];
                if(val[newi]==INF)
                    continue;
                int newS=S;
                if(val[newi]!=-1)
                    newS|=val[newi];
                if(dp[newi][newS]>dp[i][S]+1){
                    dp[newi][newS]=dp[i][S]+1;
                    if(!vis[newi][newS]){
                        vis[newi][newS]=true;
                        Q.push(make_pair(newi,newS));
                    }
                }
            }
        }
        int res=INF;
        for(int i=0;i<size;i++)
            res=min(res,dp[i][(1<<n)-1]);
        printf("%d\n",res);
    }
}aho;

char s[1005],t[50005];

int main(){
    int n,m;
    while(scanf("%d%d",&n,&m)){
        if(n==0&&m==0)
            break;
        aho.init();
        for(int i=0;i<n;i++){
            scanf("%s",s);
            aho.insert(s,(1<<i));
        }
        for(int i=0;i<m;i++){
            scanf("%s",t);
            aho.insert(t,INF);
        }
        aho.build();
        aho.solve(n);
    }
    return 0;
}
View Code

思路二:既然一步一步地走会TLE和MLE,那么我能否预处理一下,使得每次不是一步一步地走,而是很多步很多步的走,即我只在包含了合法串的节点之间走。答案是可以的。首先我预处理出所有包含合法串的节点共cnt个。AC自动机上的路径长度为1,所以我用cnt次bfs求出这些节点的最短路。

bfs的写法很关键。我发现如果我求出了a和c之间的最短路,但是这条路径中间还有一个包含了合法串的节点b,如果我在DP的时候从a直接走到了c,那么就会遗漏节点b所包含的合法串。所以a到c应该设置为不可达。所以我在bfs的时候,每到达一个包含了合法串的节点,就不再继续bfs了。

这样预处理出所有合法节点之间的最短路后,我就可以直接在这些点之间进行dp了。

dp[i][S]表示走到合法节点i,且包含的合法串的状态为S时的最短长度。此时i的取值范围由cnt决定,题目给出的合法串的个数只有10,那么我们在AC自动机build之后的合法节点个数也不会很多,我试了下,这一维开到50就够了。

为了将初始状态加入SPFA的队列,我从节点0进行一次bfs。更多的细节注释在了代码之中。(感觉这是我目前写过的最难的AC自动机- -)

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
#include <queue>
#include <set>
#include <ctime>
using namespace std;

const int MAX_NODE=60005;
const int SIGMA_SIZE=2;
const int INF=1e9+7;
const int MAXM=55;

int idx(char c){
    return c-'0';
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE],depth[MAX_NODE];
    int dp[MAX_NODE][1<<10];
    bool vis[MAX_NODE][1<<10];

    int key[MAXM],antikey[MAX_NODE],cnt;
    bool used[MAX_NODE];
    int d[MAX_NODE];
    int path[MAXM][MAXM];
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(char *s,int num){
        int n=strlen(s);
        int now=0;
        for(int i=0;i<n;i++){
            int c=idx(s[i]);
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]|=num;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            if(!next[0][i])
                next[0][i]=0;
            else{
                fail[next[0][i]]=0;
                que.push(next[0][i]);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            if(val[fail[u]]==-1)
                val[u]=-1;
            if(val[fail[u]]>=0)
                val[u]|=val[fail[u]];
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[next[u][i]]=y;
                    que.push(x);
                }
            }
        }
    }
    queue< pair<int,int> > Q;
    void bfs(int u){
        memset(used,false,sizeof(used));
        memset(d,-1,sizeof(d));//初始化为-1,如果没有被更新,就表示不可达
        Q.push(make_pair(u,0));//我没有将d[u]设为0,而是保持为-1,是为了防止后面DP的时候在自己和自己之间走,要不然特判一下也行
        used[u]=true;
        while(!Q.empty()){
            int i=Q.front().first;
            int dis=Q.front().second;
            Q.pop();
            for(int k=0;k<SIGMA_SIZE;k++){
                int newi=next[i][k];
                if(used[newi]||val[newi]==-1)//如果bfs到了非法节点,就直接continue
                    continue;//在这里处理了非法节点之后,后面就再也不用处理非法节点了
                d[newi]=dis+1;
                used[newi]=true;
                if(val[newi]==0)//只有当不是包含了合法串的节点时,才可以继续bfs,防止后面DP会跳过合法节点
                    Q.push(make_pair(newi,dis+1));
            }
        }
    }
    void solve(int n){
        cnt=0;
        for(int i=0;i<size;i++){
            if(val[i]>0){
                key[i]=cnt;
                antikey[cnt++]=i;
            }
        }
        for(int i=0;i<cnt;i++){
            bfs(antikey[i]);
            for(int j=0;j<cnt;j++)
                path[i][j]=d[antikey[j]];
        }
        for(int i=0;i<cnt;i++){
            for(int S=0;S<(1<<n);S++)
                dp[i][S]=INF,vis[i][S]=false;
        }
        bfs(0);
        for(int i=0;i<cnt;i++){
            if(d[antikey[i]]!=-1){
                Q.push(make_pair(i,val[antikey[i]]));
                dp[i][val[antikey[i]]]=d[antikey[i]];
                vis[i][val[antikey[i]]]=true;
            }
        }
        while(!Q.empty()){
            int i=Q.front().first;
            int S=Q.front().second;
            Q.pop();
            vis[i][S]=false;
            for(int newi=0;newi<cnt;newi++){
                if(path[i][newi]==-1)    continue;//如果i到newi不可达直接continue
                int newS=S|val[antikey[newi]];
                if(dp[newi][newS]>dp[i][S]+path[i][newi]){
                    dp[newi][newS]=dp[i][S]+path[i][newi];
                    if(!vis[newi][newS]){
                        Q.push(make_pair(newi,newS));
                        vis[newi][newS]=true;
                    }
                }
            }
        }
        int res=INF;
        for(int i=0;i<cnt;i++)
            res=min(res,dp[i][(1<<n)-1]);
        printf("%d\n",res);
    }
}aho;

char s[1005],t[50005];

int main(){
    int n,m;
    while(scanf("%d%d",&n,&m)){
        if(n==0&&m==0)
            break;
        aho.init();
        for(int i=0;i<n;i++){
            scanf("%s",s);
            aho.insert(s,(1<<i));
        }
        for(int i=0;i<m;i++){
            scanf("%s",t);
            aho.insert(t,-1);
        }
        aho.build();
        aho.solve(n);
    }
    return 0;
}
View Code

17.hdu6086【待补】

题意:01串问题。给出n个模式串,让你构造长度为2L的反回文串,使其包含这n个模式串,问你构造方案数。

我构造了前L个字母之后,这个串的后L的字母就确定了。

(1)模式串出现在前面L个字母,我们只需要把所有的模式串加入AC自动机即可

(2)模式串出现了后面L个字母,我们只需要把所有的模式串反向并将字母翻转加入AC自动机即可,这样就转化为了前L个字母的情况

(3)模式串出现在中间,这种情况其实也可以转化成前L个字母的情况,对所有的模式串进行分析,对一个模式串从中间所有位置断开都试一下,前后补齐后看是否符合反回文串的定义,若符合,则将前半部分加入AC自动机

最后,问题就变成了构造长度为L的串,使其包含(1)(2)(3)中所述的所有模式串的方案数。

 

 

 

 

接下来做几道AC自动机+数位DP的题。

先用简单的题目复习一下数位DP

hdu2089 不要62 求所给区间中,不包含4和62的数的个数

#include <bits/stdc++.h>
using namespace std;

typedef long long int ll;
const int MAXN=100005;

int dp[10][10];//dp[i][j]表示长度为i,且末尾数字为j的合法数的个数
int digits[10];//对0~n的这个n进行十进制分解
int tot;

void init(){
    memset(dp,0,sizeof(dp));
    dp[0][0]=1;
    for(int i=1;i<=7;i++){//题目所给范围为7位数
        for(int j=0;j<10;j++){//枚举当前位末尾数字
            if(j==4)    continue;
            for(int k=0;k<10;k++){//枚举上一位末尾数字
                if(!(j==6&&k==2||k==4))
                    dp[i][j]+=dp[i-1][k];
            }
        }
    }
}

int solve(int n){//【小于】n的数中,有多少个数符合条件
    int num=n;
    int res=0;
    tot=1;
    while(num){//十进制分解
        digits[tot++]=num%10;
        num/=10;
    }
    digits[tot]=0;//为了使下面的DP不特判,我将digits[tot]设为一个不影响DP正确结果的值
    for(int i=tot-1;i>=0;i--){//从高位到低位枚举
        for(int j=0;j<digits[i];j++){//枚举小于当前位的数
            if(!(j==2&&digits[i+1]==6||j==4))//只加上合法的情况
                res+=dp[i][j];
        }
        if(digits[i]==4||(digits[i]==2&&digits[i+1]==6))//枚举到不合法情况的时候,直接break
            break;
    }
    return res;
}

int main(){
    int n,m;
    init();
    while(scanf("%d%d",&n,&m)){
        if(n==0&&m==0)
            break;
        printf("%d\n",solve(m+1)-solve(n));
    }
    return 0;
}
View Code

这种写法每次都要现推,非常麻烦,数位DP有一种通用的dfs写法。

1.zoj3494

题意:给出若干个非法01串,再给出一个区间,问你这个区间内有多少个数的BCD码不包含任何一个非法01串,很明显的数位DP,但是非法01串有若干个,所以考虑用AC自动机来做

别人的题解:

转载请注明出处,谢谢http://blog.csdn.net/acm_cxlove/article/details/7854526       by---cxlove 

题目:给出一些模式串,给出一个范围[A,B],求出区间内有多少个数,写成BCD之后,不包含模式串

http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3494 

经典的AC自动机+数位DP。

好题,将这二者结合在了一起。。。ORZ

容易弄混的是BCD是二进制,而且并非普通的二进制,而我们的数为10进制。

这里就有一个转换,bcd[i][j]表示状态i经过数字j达到的合法状态

另外数位DP也是很经典,感觉数位DP的dfs写法非常好,通用

不过这题注意一下前导0的问题

另外需要注意的是我们求数位DP的话,需要把左区间-1,这里需要高精度-1

别人的代码:

#include<iostream>  
#include<cstdio>  
#include<map>  
#include<cstring>  
#include<cmath>  
#include<vector>  
#include<algorithm>  
#include<set>  
#include<string>  
#include<queue>  
#define inf 1<<30  
#define M 60005  
#define N 10005  
#define maxn 300005  
#define eps 1e-10  
#define zero(a) fabs(a)<eps  
#define Min(a,b) ((a)<(b)?(a):(b))  
#define Max(a,b) ((a)>(b)?(a):(b))  
#define pb(a) push_back(a)  
#define mem(a,b) memset(a,b,sizeof(a))  
#define LL long long  
#define lson step<<1  
#define rson step<<1|1  
#define MOD 1000000009  
using namespace std;  
struct Trie  
{  
    Trie *next[2];  
    Trie *fail;  
    int isword,kind;  
};  
Trie *que[M],s[M];  
int idx;  
char str[25];  
int bcd[2005][10]; //bcd[i][j]表示在结点i,经过一个数字j,到达的结点  
LL dp[205][2005];   //dp[i][j]表示长度为i,位于结点j的个数  
int bit[205],len,n;  
Trie *NewNode()  
{  
    Trie *tmp=&s[idx];  
    mem(tmp->next,NULL);  
    tmp->isword=0;  
    tmp->fail=NULL;  
    tmp->kind=idx++;  
    return tmp;  
}  
void Insert(Trie *root,char *s,int len)  
{  
    Trie *p=root;  
    for(int i=0; i<len; i++)  
    {  
        if(p->next[s[i]-'0']==NULL) p->next[s[i]-'0']=NewNode();  
        p=p->next[s[i]-'0'];  
    }  
    p->isword=1;  
}  
void Bulid_fail(Trie *root)  
{  
    int head=0,tail=0;  
    que[tail++]=root;  
    root->fail=NULL;  
    while(head<tail)  
    {  
        Trie *tmp=que[head++];  
        for(int i=0; i<2; i++)  
        {  
            if(tmp->next[i])  
            {  
                if(tmp==root) tmp->next[i]->fail=root;  
                else  
                {  
                    Trie *p=tmp->fail;  
                    while(p!=NULL)  
                    {  
                        if(p->next[i])  
                        {  
                            tmp->next[i]->fail=p->next[i];  
                            break;  
                        }  
                        p=p->fail;  
                    }  
                    if(p==NULL) tmp->next[i]->fail=root;  
                }  
                if(tmp->next[i]->fail->isword) tmp->next[i]->isword=tmp->next[i]->fail->isword;  
                que[tail++]=tmp->next[i];  
            }  
            else if(tmp==root) tmp->next[i]=root;  
            else tmp->next[i]=tmp->fail->next[i];  
        }  
    }  
}  
//状态当前在状态pre,经过一个数字num之后到达哪个状态  
//如果不合法,返回-1  
int BCD(int pre,int num)  
{  
    if(s[pre].isword) return -1;  
    int cur=pre;  
    for(int i=3;i>=0;i--)  
    {  
        int k=(num>>i)&1;  
        if(s[cur].next[k]->isword) return -1;  
        else cur=s[cur].next[k]->kind;  
    }  
    return cur;  
}  
void Get_next()  
{  
    for(int i=0;i<idx;i++)  
    {  
        for(int j=0;j<10;j++)  
        {  
            bcd[i][j]=BCD(i,j);  
        }  
    }  
}  
//数位DP,长度为len,当前状态为pos,是否有限制,是否有前导0  
LL dfs(int len,int pos,bool limit,bool zero)  
{  
    if(len==0) return 1;  
    if(!limit&&dp[len][pos]!=-1) return dp[len][pos];  
    LL ans=0;  
    //如果之前全为0,但是由于0是不能计算的,所以当前不为最低位  
    if(len>1&&zero)  
    {  
        ans+=dfs(len-1,pos,limit&&bit[len]==0,true);  
        if(ans>=MOD) ans-=MOD;  
    }  
    else  
    {  
        //判断转移是否合法  
        if(bcd[pos][0]!=-1) ans+=dfs(len-1,bcd[pos][0],limit&&bit[len]==0,false);  
        if(ans>=MOD) ans-=MOD;  
    }  
    int up=limit?bit[len]:9;  
    for(int i=1;i<=up;i++)  
    {  
        if(bcd[pos][i]!=-1)  
        {  
            ans+=dfs(len-1,bcd[pos][i],limit&&i==up,false);  
            if(ans>=MOD) ans-=MOD;  
        }  
    }  
    if(!limit&&!zero) dp[len][pos]=ans;  
    return ans;  
}  
LL cal(char *s,int l)  
{  
    mem(dp,-1);  
    for(int i=1;i<=l;i++) bit[l-i+1]=s[i-1]-'0';  
    dfs(l,0,true,true);  
}  
char A[205],B[205];  
//高精度-1,这样会遗留前导0,无所谓了。。。  
void sub(char *s,int len)  
{  
    for(int i=len-1;i>=0;i--)  
    {  
        if(s[i]=='0') s[i]='9';  
        else  
        {  
            s[i]--;  
            break;  
        }  
    }  
}  
int main()  
{  
    int t;  
    scanf("%d",&t);  
    while(t--)  
    {  
        idx=0;  
        Trie *root=NewNode();  
        scanf("%d",&n);  
        for(int i=1; i<=n; i++)  
        {  
            scanf("%s",str);  
            Insert(root,str,strlen(str));  
        }  
        Bulid_fail(root);  
        Get_next();  
        scanf("%s",A);  
        sub(A,strlen(A));  
        LL ans=-cal(A,strlen(A));  
        scanf("%s",B);  
        ans+=cal(B,strlen(B));  
        printf("%lld\n",(ans%MOD+MOD)%MOD);  
    }  
    return 0;  
}
View Code

 

2.codeforces 434C Tachibana Kanade's Tofu

题意:进制为m的数位dp。一旦包含某些连续的数,就要对该数的value加上给定值,如果一个数的value大于给定的阈值k,则该数不合法,问给定的区间内有多少个合法的数

这题写了一半懒得调了,感觉现在没时间学数位DP,所以AC自动机+数位DP就先放一放吧

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
#include <queue>
#include <set>
#include <ctime>
using namespace std;
typedef long long int ll;
const int MAX_NODE=205;
const int SIGMA_SIZE=25;
const int INF=1e9+7;
const ll mod=1e9+7;

int idx(char c){
    return c-'0';
}

struct Aho{
    int next[MAX_NODE][SIGMA_SIZE];
    int fail[MAX_NODE],val[MAX_NODE];
    int f[205][25][505];//f[i][j]表示长度为i,末尾数字为j,value为k的串的个数
    int dp[205][205][505];//dp[i][j][k]表示长度为i,当前在节点i,value为k的串的个数
    int size;
    queue<int> que;
    void init(){
        memset(next[0],0,sizeof(next[0]));
        fail[0]=val[0]=0;
        size=1;
    }
    void insert(int *num,int tot,int value){
        int now=0;
        for(int i=0;i<tot;i++){//从高位到低位
            int c=num[i];
            if(!next[now][c]){
                memset(next[size],0,sizeof(next[size]));
                fail[size]=val[size]=0;
                next[now][c]=size++;
            }
            now=next[now][c];
        }
        val[now]+=value;
    }
    void build(){
        fail[0]=0;
        for(int i=0;i<SIGMA_SIZE;i++){
            if(!next[0][i])
                next[0][i]=0;
            else{
                fail[next[0][i]]=0;
                que.push(next[0][i]);
            }
        }
        while(!que.empty()){
            int u=que.front();
            que.pop();
            val[u]+=val[fail[u]];
            for(int i=0;i<SIGMA_SIZE;i++){
                int &x=next[u][i],y=next[fail[u]][i];
                if(!x)
                    x=y;
                else{
                    fail[next[u][i]]=y;
                    que.push(x);
                }
            }
        }
    }
    void prework(int n,int m,int k){//n表示需要预处理的长度,m表示进制,k表示value的阈值
        memset(dp,0,sizeof(dp));
        memset(f,0,sizeof(f));
        dp[0][0][0]=1;
        for(int len=0;len<n;len++){
            for(int i=0;i<size;i++){
                for(int j=0;j<=k;j++){
                    for(int x=0;x<m;x++){
                        int newlen=len+1;
                        int newi=i+1;
                        int newj=j+val[newi];
                        if(newj>k)  continue;
                        printf("dp[%d][%d][%d]=%d dp[%d][%d][%d]=%d after add res=%d\n",newlen,newi,newj,dp[newlen][newi][newj],len,i,j,dp[len][i][j],dp[newlen][newi][newj]+dp[len][i][j]);
                        dp[newlen][newi][newj]+=dp[len][i][j];
                        dp[newlen][newi][newj]%=mod;
                        printf("f[%d][%d][%d]=%d dp[%d][%d][%d]=%d after add res=%d\n",newlen,x,newj,f[newlen][x][newj],len,i,j,dp[len][i][j],f[newlen][x][newj]+dp[len][i][j]);
                        f[newlen][x][newj]+=dp[len][i][j];
                        f[newlen][x][newj]%=mod;
                    }
                }
            }
        }
    }
    ll solve(int num[],int n,int k){
        ll res=0;
        for(int i=n-1;i>=0;i--){//从高位到低位
            for(int j=0;j<num[i];j++){
                for(int x=0;x<=k;x++){
                    res+=f[i+1][j][x];
                    res%=mod;
                }
            }
        }
        return res;
    }
}aho;

const int MAXM=205;
int lowbound[MAXM],highbound[MAXM],lowcnt,highcnt;
int num[MAXM];
int tot;

int main(){
    int n,m,k,c;
    aho.init();
    scanf("%d%d%d",&n,&m,&k);
    scanf("%d",&lowcnt);
    for(int i=0;i<lowcnt;i++)
        scanf("%d",lowbound+lowcnt-i-1);
    scanf("%d",&highcnt);
    for(int i=0;i<highcnt;i++)
        scanf("%d",highbound+highcnt-i-1);
    highbound[0]++;
    for(int i=0;i<highcnt;i++){
        if(highbound[i]>=m){
            highbound[i+1]+=highbound[i]/m;
            highbound[i]%=m;
            if(i+1==highcnt)
                highcnt++;
        }
        else
            break;
    }
    for(int i=0;i<n;i++){
        scanf("%d",&tot);
        for(int i=0;i<tot;i++)
            scanf("%d",num+i);
        int loc=0;//处理前导0
        while(loc<tot&&num[loc]==0) loc++;
        if(loc==tot){
            tot=1;
            num[0]=0;
        }
        else if(loc!=0){
            int j=0;
            for(int i=loc;i<tot;i++)
                num[j++]=num[i];
            tot=j;
        }
        scanf("%d",&c);
        aho.insert(num,tot,c);
    }
    aho.build();
    aho.prework(max(lowcnt,highcnt),m,k);
    printf("%lld\n",(aho.solve(highbound,highcnt,k)-aho.solve(lowbound,lowcnt,k)+mod)%mod);
    return 0;
}
View Code

 

posted @ 2017-08-28 14:47  nearlight  阅读(94)  评论(0)    收藏  举报