[NOI2013]矩阵游戏
题目描述
婷婷是个喜欢矩阵的小朋友,有一天她想用电脑生成一个巨大的 \(n\) 行 \(m\) 列的矩阵(你不用担心她如何存储)。她生成的这个矩阵满足一个神奇的性质:若用 \(F[i][j]\) 来表示矩阵中第 \(i\) 行第 \(j\) 列的元素,则 \(F[i][j]\) 满足下面的递推式:
递推式中 \(a,b,c,d\) 都是给定的常数。
现在婷婷想知道 \(F[n][m]\) 的值是多少,请你帮助她。由于最终结果可能很大,你只需要输出 \(F[n][m]\) 除以 \(1,000,000,007\) 的余数。
分析
首先,看到题目已经直接把递推式给出来了,数据范围很大,硬推肯定\(g\)掉,考虑优化递推的过程。
通过观察递推式我们可以发现:\(F[i,j]\)仅由一个先前的状态转移过来,所需要的维度很少,所以考虑矩阵加速递推。
设状态矩阵为 \(\begin{bmatrix} F[i,j] & 1 \end{bmatrix}\) ,则初始矩阵为\(\begin{bmatrix} 1 & 1 \end{bmatrix} \tag{2}\) 。由于 \(F[i,j]=a\times F[i][j-1]+b (j\neq 1)\) ,构造出第一个转移矩阵 \(M_1 = \begin{bmatrix} a & 0 \\ b & 1 \end{bmatrix} \tag{2}\) 。这样,我们就从 \(\begin{bmatrix} F[i,j] & 1 \end{bmatrix} \tag{2}\) 转移到了 \(\begin{bmatrix} aF[i,j]+b & 1 \end{bmatrix} \tag{2}\),也就是从 \(\begin{bmatrix} F[i,j] & 1 \end{bmatrix} \tag{2}\) 转移到了 \(\begin{bmatrix} F[i,j+1] & 1 \end{bmatrix} \tag{2}\)。
这是在同一行中转移的情况,再来看从上一行转移到下一行,根据\(F[i,1]=c\times F[i-1][m]+d (i\neq 1)\),第二个矩阵应为 \(M_2 = \begin{bmatrix} c & 0 \\ d & 1 \end{bmatrix} \tag{2}\)。
有了两个转移矩阵,我们就可以进行快速幂了。从\(F[1,1]\)到\(F[n,m]\),我们共需要在行之间转移\(n-1\)次,\(1\)到\(n-1\)行每一行会向右转移\((m-1)\)次、向下转移\(1\)次;最后第\(n\)行会向右转移\((m-1)\)次。因此,答案应为 \(\begin{bmatrix} 1 & 1 \end{bmatrix} \tag{2} \times ({M_1}^{m-1 } \times M_2)^{n-1} \times M_1^{m-1}\)。
但是还有一个问题。输入的\(m\)和\(n\)最大有\(10^{1000000}\),读入都不好办,更别说快速幂了。这是,我们就要请出我们的欧拉定理——若\(a\)、\(p\)互质,那么对于任意正整数\(b\),有 \(a^b \equiv a^{ b\%\phi(p) } (mod p)\)。由于\(p\)是\(1e9+7\),我们在读入\(n\)和\(m\)的时候边读边对\((p-1)\)取模,就能保证\(n\)和\(m\)都是在我们可以接受的范围内。
代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
long long read()
{
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9')
{
if(c=='-')f=-1;
c=getchar();
}
while(c>='0'&&c<='9')
{
x=(x<<3)+(x<<1)+(c^48);
c=getchar();
}
return x*f;
}
const int mod=1e9+7;
int n,m,a,b,c,d;
string ns,ms;
struct matrix
{
int mat[10][10],n,m;
matrix(int a,int b)
{
memset(mat,0,sizeof(mat));
n=a,m=b;
}
bool set_i()
{
if(n!=m)return false;
for(int i=1;i<=n;i++)
mat[i][i]=1;
return true;
}
matrix operator *(matrix a)
{
matrix res(n,a.m);
for(int i=1;i<=n;i++)
for(int j=1;j<=a.m;j++)
for(int k=1;k<=m;k++)
res.mat[i][j]=(res.mat[i][j]+mat[i][k]*a.mat[k][j]%mod)%mod;
return res;
}
};
void display(matrix a)
{
for(int i=1;i<=a.n;i++)
{
for(int j=1;j<=a.m;j++)
printf("%d ",a.mat[i][j]);
puts("");
}
}
matrix qpow(matrix a,int b)
{
matrix ans(a.n,a.n);
ans.set_i();
while(b)
{
if(b&1)ans=(ans*a);
a=(a*a);
b>>=1;
}
return ans;
}
signed main()
{
cin>>ns>>ms;
a=read();b=read();c=read();d=read();
for(int i=0;i<ns.length();i++)
{
if(a == 1)n=((n<<3)+(n<<1)+(ns[i]^48))%mod;
else n=((n<<3)+(n<<1)+(ns[i]^48))%(mod-1);
}
for(int i=0;i<ms.length();i++)
{
if(c==1)m=((m<<3)+(m<<1)+(ms[i]^48))%mod;
else m=((m<<3)+(m<<1)+(ms[i]^48))%(mod-1);
}
matrix x(1,2),p1(2,2),p2(2,2);
x.mat[1][1]=x.mat[1][2]=1;
p1.mat[1][1]=a;p1.mat[2][1]=b;p1.mat[2][2]=1;
p2.mat[1][1]=c,p2.mat[1][2]=0,p2.mat[2][1]=d,p2.mat[2][2]=1;
p1=qpow(p1,m-1);
p2=p1*p2;
p2=qpow(p2,n-1);
p1=p2*p1;
x=x*p1;
printf("%lld",x.mat[1][1]%mod);
return 0;
}