bzoj1009 [HNOI2008]GT考试——KMP+矩阵快速幂优化DP

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=1009

字符串计数DP问题啊...连题解都看了好多好久才明白,别提自己想出来的蒟蒻我...

首先要设计一个不太好想的状态:f[i][j]表示大串上到第 i 位时有小串前 j 位的后缀,且不包含整个小串的方案数;

也就是如果小串是 12312 , f[5][3] 表示目前大串的情况是 **123... ;

这个状态要从 i 转移到 i+1 ,还需要一个帮助它的数组 a,a[i][j]表示在长度为 i 的后缀后面加一个数字能变成长度为 j 的后缀的方案数;

也就是说,对于 12312,从0到4的 a 数组应该如下:

9 1 0 0 0

8 1 1 0 0

8 1 0 1 0

9 0 0 0 1

8 1 0 0 0

a 数组的定义可以联想到 kmp 算法,事实上它就是通过 kmp 算法的 nxt 数组求得;

于是就可以得到转移方程:f[i][j] = ∑(0<=k<m) f[i-1][k] * a[k][j]

然后发现对于每一步,进行的转移都是相同的;

所以可以用矩阵快速幂来优化,转移矩阵就是 a 数组;

看了好多好多博客才明白...

这篇博客写得很好:https://blog.csdn.net/loi_dqs/article/details/50897662

尤其是代码真的简洁!所以模仿着写了。

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
typedef long long ll;
int n,m,mod,nxt[25],sum;
char s[25];
struct Matrix{
    int n,m,a[25][25];
    Matrix(int x=0,int y=0):n(x),m(y) {memset(a,0,sizeof a);}
    void init()
    {
        for(int i=0;i<n;i++)a[i][i]=1;//0 ~ n-1
    }
    Matrix operator * (const Matrix &y) const
    {
        Matrix x(n,y.m);
        for(int i=0;i<n;i++)//从0到n-1 
            for(int k=0;k<m;k++)
                for(int j=0;j<y.m;j++)
                    (x.a[i][j]+=(ll)a[i][k]*y.a[k][j]%mod)%=mod;
        return x;
    }
};
void getnxt()
{
    nxt[0]=nxt[1]=0;//第0位有字符,但含义是无匹配 
    for(int i=1;i<m;i++)
    {
//        int k=i;
        int k=nxt[i];
        while(s[i]!=s[k]&&k)k=nxt[k];
        nxt[i+1]=(s[i]==s[k])?k+1:0;//前一位的nxt冒进一位,对应下面从0开始的字符串匹配 
    }
}
Matrix pw(Matrix x,int k)
{
    Matrix ret(x.n,x.m); ret.init();
    for(;k;k>>=1,x=x*x)
        if(k&1)ret=ret*x;
    return ret;
}
int main()
{
    scanf("%d%d%d%s",&n,&m,&mod,&s);
    getnxt();
    Matrix f(m,m);
    for(int i=0;i<m;i++)
        for(int j='0';j<='9';j++)//第i位上填j 
        {
            int k=i;//已经有长度为i的前缀,而k对应字符串上i的后一位 
            while(k&&s[k]!=j)k=nxt[k];//冒进一位的nxt,表示给i后一位进行匹配 
            if(s[k]==j)k++;//匹配到了第k位,也就是有了k+1长度的前缀
            if(k!=m)f.a[i][k]++; 
        }
    Matrix fn=pw(f,n);
    Matrix ans(1,m);
    ans.a[0][0]=1; ans=ans*fn;
    for(int i=0;i<m;i++)
        (sum+=ans.a[0][i])%=mod;
    printf("%d",sum);
    return 0;
}

 

posted @ 2018-07-02 15:45  Zinn  阅读(129)  评论(0编辑  收藏  举报