BZOJ 4818 [Sdoi2017]序列计数 ——矩阵乘法

发现转移矩阵是一个循环矩阵。

然后循环矩阵乘以循环矩阵还是循环矩阵。

据说还有FFT并且更优的做法。

之后再看吧

#include <map>
#include <cmath>
#include <queue>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define F(i,j,k) for (int i=j;i<=k;++i)
#define D(i,j,k) for (int i=j;i>=k;--i)
#define ll long long
#define mp make_pair
#define md 20170408
#define maxn 20000005
 
int pri[maxn],top,p,n,m;
bool vis[maxn];
int cnt[2][101];//0 质数 1 非质数 
 
void init()
{
    memset(vis,false,sizeof vis);
    cnt[0][1]++;cnt[1][1]++;
    F(i,2,m)
    {
        if (!vis[i]) pri[++top]=i,cnt[0][i%p]++;
        else cnt[0][i%p]++,cnt[1][i%p]++;
        for (int j=1;j<=top&&(ll)i*pri[j]<=m;++j)
        {
            vis[i*pri[j]]=true;
            if (i%pri[j]==0) break;
        }
    }
}
 
struct Matrix{
    int x[101][101];
    void init(){memset(x,0,sizeof x);}
    void build1()
    {
        init();
        F(i,0,p-1) F(j,0,p-1)
            (x[i][(i+j)%p]+=cnt[1][j])%=md;
    }
    void builde()
    {
        init();
        F(i,0,p-1) x[i][i]=1;
    }
    void build0()
    {
        init();
        F(i,0,p-1) F(j,0,p-1)
            (x[i][(i+j)%p]+=cnt[0][j])%=md;
    }
    Matrix operator * (Matrix a) {
        Matrix ret;
        ret.init();
        F(i,0,p-1) F(j,0,p-1) F(k,0,p-1)
            (ret.x[i][j]+=x[i][k]*a.x[k][j])%=md;
        return ret;
    }
    Matrix operator ^ (Matrix a){
        Matrix ret;
        ret.init();
        F(j,0,p-1) F(k,0,p-1)
            ret.x[0][j]=((ll)ret.x[0][j]+(ll)x[0][k]*a.x[k][j])%md;
        F(i,1,p-1)
        {
            ret.x[i][0]=ret.x[i-1][p-1];
            F(j,1,p-1) ret.x[i][j]=ret.x[i-1][j-1];
        }
        return ret;
    }
    void print()
    {
        printf("|----------|\n");
        F(i,0,p-1)
        {
            F(j,0,p-1)
                printf("%d ",x[i][j]);
            printf("\n");
        }
        printf("|----------|\n");
    }
}A,B,S,C,D;
 
int main()
{
    scanf("%d%d%d",&n,&m,&p); init();
    S.init();S.x[0][0]=1;
    int b=n,ans=0;
    A.build0(); C.builde();
    while (b)
    {
        if (b&1) C=C^A;
        A=A^A;
        b>>=1;
    }
    ans+=C.x[0][0];
    A.build1(); C.builde(); b=n;
    while (b)
    {
        if (b&1) C=C^A;
        A=A^A;
        b>>=1;
    }
    ans-=C.x[0][0];
    printf("%d\n",(ans%md+md)%md);
}

  

posted @ 2017-04-20 10:57  SfailSth  阅读(188)  评论(0编辑  收藏  举报