bzoj1009: [HNOI2008]GT考试

kmp+矩阵乘法。

好久以前做过,但我今天居然死活看不懂以前的程序,再写一发。

用f[i][j]表示前i个准考证号匹配了前j个不吉利数字的方案数。

tmp[i][j]表示匹配了前i个不吉利数字以后,增加一个字符可以匹配前j个不吉利数字的方案数。

我们可以枚举(i+1)位的数字,并用kmp求得的next数组进行转移,就可以求出tmp数组。

很显然有(就是我完全不知道为什么。。)

f[i][j]=f[i-1][0]*tmp[0][j]+f[i-1][1]*tmp[1][j]+……+f[i-1][m-1]*tmp[m-1][j]。

tmp[i][j]i可以转移到j(m是不合法的,所以这里没有m)。

而tmp在每次递推时都是相同的,而n又很大。

所以可以用矩阵快速幂在log n的时间求出结果。

初始条件f[0][0]=1。

答案表达式很长,请见代码。

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int maxn = 25;

int n,m,mod;
char s[maxn];
int next[maxn];

struct Matrix {
    int a[maxn][maxn];
    
    int* operator [] (int x) {
        return a[x];
    }
    
    Matrix operator * (Matrix b) {
        Matrix res;
        for(int i=0;i<m;i++)
        for(int j=0;j<m;j++)
        for(int k=0;k<m;k++)
            res[i][k]=(res[i][k]+a[i][j]*b[j][k])%mod;
        return res;
    }
    
    Matrix operator ^ (int e) {
        Matrix res,tmp=*this;
        res.init();
        while(e) {
            if(e&1) res=res*tmp;
            tmp=tmp*tmp;
            e>>=1;
        }
        return res;
    }
    
    void init() {
        memset(a,0,sizeof(a));    
        for(int i=0;i<m;i++) a[i][i]=1;
    }
    
    void debug() {
        for(int i=0;i<m;i++) {
            for(int j=0;j<m;j++) printf("%d ",a[i][j]);
            printf("\n");    
        }
    }
    
    Matrix() {
        memset(a,0,sizeof(a));    
    }
}tmp,f,res;

void kmp() {
    scanf("%s",s+1);
    int j=0;
    for(int i=2;i<=m;i++) {
        while(j&&s[i]!=s[j+1]) j=next[j];
        if(s[j+1]==s[i]) j++;
        next[i]=j;
    }
}

void build() {
    scanf("%d%d%d",&n,&m,&mod);
    kmp();
    for(int i=0;i<m;i++) 
    for(int k=0;k<=9;k++) {
        if((s[i+1]-'0')==k) {
            if(i+1!=m) tmp[i][i+1]=1;
        }
        else {
            int j=next[i];
            while(j&&(s[j+1]-'0')!=k) j=next[j];
            if((s[j+1]-'0')==k) j++;     
            tmp[i][j]++;
        }
    }
}

void solve() {
    f[0][0]=1; //f[0][1]=1;
    res=f*(tmp^(n));
    int ans=0;
    for(int i=0;i<m;i++) ans=(ans+res[0][i])%mod;
    printf("%d\n",ans);
}

int main() {
    build();
    solve();
    return 0;
}
posted @ 2016-07-16 22:06  invoid  阅读(211)  评论(0编辑  收藏  举报