CF1117D Magic Gems

原题传送门

题目大意:xht37有很多魔法宝石。每颗魔法宝石可以分解成\(m\)颗普通宝石,魔法宝石和普通宝石都占据\(1\)体积的空间,但普通宝石不能再被分解。

xht37想要使一些魔法宝石分解,使得所有宝石占据的空间恰好为\(n\)单位体积。显然,一个魔法宝石分解后会占据\(m\)体积空间,不分解的魔法宝石仍占据\(1\)体积空间。

现在xht37想要求出有多少种分解方案,可以让最后得到的宝石恰好占据\(n\)单位体积。两种分解方案不同当且仅当分解的魔法宝石数量不同,或者是所用的宝石的编号不同。

数据范围:\(1 \le n\le 10^{18},2\le m\le 100\)

思路:显然易见的\(dp\)做法。用\(dp[i]\)表示占据\(i\)个体积空间时的分解方案,可以写出转移方程为:

\[\begin{cases} dp[i]=dp[i-1]\qquad i\lt m \\ dp[i]=dp[i-1]+dp[i-m]\qquad i\ge m \end{cases} \]

考虑到本题的\(n\le 10^{18}\),线性递推必然超时,因此我们考虑用矩阵快速幂优化,把复杂度降低到\(log\)级别。

我们考虑这样两个向量:

\[F(n)=\begin{bmatrix} dp_n& dp_{n-1}& ···& dp_{n-m+1}\end{bmatrix} \]

\[F(n-1)=\begin{bmatrix} dp_{n-1}& dp_{n-2}& ···& dp_{n-m}\end{bmatrix} \]

如何找到一个矩阵\(A\),使得\(F(n)=F(n-1)*A\)

得到\(A\)之后,我们把这个递推柿子一直写下去就可以得到:\(F(n)=F(m-1)*A^{n-m+1}\)

由于\(F(m-1)=\begin{bmatrix} dp_{m-1}& dp_{m-2}& ···& dp_{0}\end{bmatrix}\) 由于\(dp_0=1\),通过转移方程可得:\(F(m-1)=\begin{bmatrix} 1& 1& ···& 1& 1\end{bmatrix}\)为一个已知的全\(1\)向量

那我们通过矩阵快速幂求出\(A^{n-m+1}\)之后就可以得到\(F(n)\),而它的第一项就是所要求的答案\(dp_n\)

下面我们讲一下如何求出矩阵\(A\),我们由转移方程可以得到:

\[\begin{cases} dp[n]=1*dp[n-1]+1*dp[n-m]\\ dp[n-1]=1*dp[n-1]\\ \qquad\qquad······\\ dp[n-m]=1*d[n-m] \end{cases} \]

显然这个\(m\times m\)的矩阵\(A\)的各行各列值即为上面\(m\)个柿子的系数,得:

\[A=\begin{bmatrix} 1& 0& 0& ··· &0 &0 &1 \\ 1& 0& 0& ··· &0 &0 &0 \\ 0& 1& 0& ··· &0 &0 &0 \\ 0& 0& 1& ··· &0 &0 &0 \\ ···&···&···&···&···&···&···\\ 0& 0& 0& ··· &0 &0 &1 \\ \end{bmatrix} \]

至此,本题就可以在\(O(m^3logn)\)的复杂度过了

PS:注意\(n<m\)时,直接输出\(1\)即可,不单独讨论的话,会超时......

Code:

#include <bits/stdc++.h>
using namespace std;
const int N=1e2+10,mod=1e9+7;
typedef long long ll;
ll n,m;
struct Mat{
    ll mat[N][N];
    Mat() {memset(mat,0,sizeof(mat));}
    Mat operator*(const Mat &b)const {
        Mat res;
        for(int i=0;i<100;i++){
            for(int j=0;j<100;j++){
                for(int k=0;k<100;k++){
                    res.mat[i][j]+=mat[i][k]*b.mat[k][j]%mod;
                    res.mat[i][j]%=mod;
                }
            }
        }
        return res;
    }
}base,ans;
void mat_power(ll k){
    for(;k;k>>=1){
        if(k&1) ans=ans*base;
        base=base*base;
    }
    return ;
}
int main(){
    scanf("%lld%lld",&n,&m);
    if(n<m) cout<<"1"<<endl;
    else{
        for(int i=0;i<m;i++) ans.mat[0][i]=1;
        base.mat[0][0]=base.mat[0][m-1]=1; 
        for(int i=1;i<m;i++) base.mat[i][i-1]=1;
        mat_power(n-m+1);
        printf("%lld",ans.mat[0][0]);
    }
    return 0;
}
posted @ 2021-12-14 16:44  Wraith-Fiee  阅读(53)  评论(0)    收藏  举报