矩阵乘法学习笔记

Martix,即矩阵,是线性代数中的一个重要内容。

定义(摘自oi-wiki):

对于矩阵 \(A\) ,主对角线是指 \(A[ i ][ i ]\) 的元素。

一般用 \(I\) 来表示单位矩阵,就是主对角线上为 1,其余位置为 0。

运算:

加减法是逐个元素进行的

我们规定:矩阵乘法只有在第一个矩阵的行数和第二个矩阵的列数相同时才有意义;

\(A\) 为一个 \(P * M\) 的矩阵,\(B\) 为一个 \(M * Q\)的矩阵

\(C[ i ][ j ] = \sum_{k=1}^{M}{A[i][k]*B[k][j]}\)

即逐个相乘再相加,得到一个 \(P*Q\)的矩阵\(C\)

矩阵乘法满足结合律但不满足交换律

快速幂

与单个数字的快速幂相同,只需要把乘号重载即可

struct Matrix{
    long long a[200][200];
    int lenx,leny;//行数,列数
    Matrix(){
        memset(a,0,sizeof(a));
    }
    Matrix friend operator*(Matrix a,Matrix b){
        Matrix res;
        for(long long i=1;i<=a.lenx;++i){
            for(long long j=1;j<=b.leny;++j){
                for(long long k=1;k<=a.leny;++k){
                    res.a[i][j]=(res.a[i][j]+a.a[i][k]*b.a[k][j]);
                }
            }
        }
        res.lenx=a.lenx,res.leny=b.leny;
        return res;
    }
}
                              
void qpow(int b) {
  while (b) {
    if (b & 1) ans = ans * base;
    base = base * base;
    b >>= 1;
  }
}

应用

luoguP1962

斐波那契数列有递推公式
\(F[i]=F[i-1]+F[i-2]\),可以在O(n)的复杂度下求出所有项,但是该题\(1<=n<=2^{63}\)递推显然会TLE。由于只要求求出第\(n\)项,所以没必要把所有的项都求出来,此时可以用矩阵乘法来加速

定义一个\(1*2\)的矩阵\(A\)
\(\begin{bmatrix} F[n-1] & F[n-2] \end{bmatrix}\),一个\(1*2\)的矩阵\(B\)\(\begin{bmatrix} F[n] & F[n-1] \end{bmatrix}\),尝试用求出矩阵\(C\),使得\(A*C=B\)

假设已经求出了\(C\),初始矩阵为\(\begin{bmatrix} 1 & 1 \end{bmatrix}\),对应斐波那契数列的第一项和第二项,\(\begin{bmatrix} 1 & 1 \end{bmatrix} * C\)对应第二项和第三项,\(\begin{bmatrix} 1 & 1 \end{bmatrix} * C * C\)对应第三项和第四项...以此类推,初始矩阵乘上\(n-2\)\(C\)就可以求出第n项,由于矩阵乘法满足结合律,先计算\(n-2\)\(C\)的乘积,即快速幂

接下来考虑怎么求出\(C\)

因为 \(F[n] = F[n-1] + F[n-2]\),所以 \(C\) 矩阵第一列应该是
\(\begin{bmatrix} 1 \\ 1 \end{bmatrix}\) ,这样在进行矩阵乘法运算的时候才能令 \(F[n-1]\)\(F[n-2]\) 相加,从而得出 \(F[n]\)。同理,为了得出 ,矩阵 \(C\) 的第二列应该为 \(\begin{bmatrix} 1 \\ 0 \end{bmatrix}\)

摘自oi-wiki
(就是懒得自己写)

Code

#include<bits/stdc++.h>
using namespace std;

const long long mod=1e9+7;

struct Matrix{
    long long a[5][5];
    Matrix(){
        memset(a,0,sizeof(a));
    }
    Matrix operator*(Matrix &b)const{
        Matrix res;
        for(long long i=1;i<=2;++i){
            for(long long j=1;j<=2;++j){
                for(long long k=1;k<=2;++k){
                    res.a[i][j]=(res.a[i][j]+a[i][k]*b.a[k][j])%mod;
                }
            }
        }
        return res;
    }
}base,ans;
inline long long read(){
    long long x=0,f=1;char c=getchar();
    while(c<'0'||c>'9'){
        if(c=='-')f=-1;c=getchar();
    }
    while(c>='0'&&c<='9'){
        x=(x<<1)+(x<<3)+(c^'0');c=getchar();
    }
    return f*x;
}
void qpow(long long x){
    while(x){
        if(x&1){
           ans=ans*base;            
        }
        base=base*base;x>>=1;
    }
}
long long n;
void work(){
     n=read();
     if(n<=2){
         printf("1\n");return;
     }
     ans.a[1][1]=1,ans.a[1][2]=1;
     base.a[1][1]=1,base.a[1][2]=1,base.a[2][1]=1,base.a[2][2]=0;
     qpow(n-2);
     printf("%lld\n",ans.a[1][1]%mod);
}
int main(){
    work();
    return 0;
}

\(update2022.02.22\)

今天做了另一道板子,LuoguP1939,发现上一次只是学会了一点点,还有相当一部分的问题没有解决

为什么初始矩阵要定义为\(\begin{bmatrix} F[i] & F[i-1] \end{bmatrix}\)?

上文中有提到,矩阵乘法是用来加速递推的,既然是递推就要从上一个状态转移到下一个状态,那么初始矩阵里存的即为状态

以P1939为例,定义\(F[i]\)为数列的第i个数,显然,其上一个状态为\(F[i-1]\),尝试把初始矩阵定义为\(\begin{bmatrix} F[i] & F[i-1] \end{bmatrix}\),那么他的上一个状态即为\(\begin{bmatrix} F[i-1] & F[i-2] \end{bmatrix}\),根据递推式\(F[i]=F[i-1]+F[i-3]\),还需要往前再找一个状态\(\begin{bmatrix} F[i-2] & F[i-3] \end{bmatrix}\),这时就会出现问题,无法直接从上一个状态递推下来。

那么把初始矩阵定义为\(\begin{bmatrix} F[i] & F[i-1]&F[i-2] \end{bmatrix}\),上一个状态为\(\begin{bmatrix} F[i-1] & F[i-2]&F[i-3] \end{bmatrix}\),乘上一个base就可以实现转移

base:

\(\begin{bmatrix} 1&1&0\\0&0&1\\1&0&0\end{bmatrix}\)

Code

struct Matrix{
    int lenx,leny;//行数,列数
    int a[10][10];
    Matrix(){
        memset(a,0,sizeof(a));
    }
    friend Matrix operator*(Matrix A,Matrix B){
        Matrix res;
        for(int i=1;i<=A.lenx;++i){
            for(int j=1;j<=B.leny;++j){
               for(int k=1;k<=B.lenx;++k){
                   res.a[i][j]=(res.a[i][j]+A.a[i][k]*B.a[k][j])%mod;                                         
               }
            }
        }
        res.lenx=A.lenx,res.leny=B.leny;
        return res;
    }
    void qpow(Matrix base,Matrix &ans,int x){
        while(x){
            if(x&1)ans=ans*base;
            base=base*base,x>>=1;
        }
    }
}mat;
int n;
void work(){
     Matrix base,ans;
     ans.a[1][1]=1,ans.a[1][2]=1,ans.a[1][3]=1,ans.lenx=1,ans.leny=3;
     base.a[1][1]=1,base.a[1][2]=1,base.a[1][3]=0,base.a[2][1]=0,base.a[2][2]=0,base.a[2][3]=1,base.a[3][1]=1,base.a[3][2]=0,base.a[3][3]=0;
     base.lenx=3,base.leny=3;
     n=read();
     if(n<=3){
         printf("1\n");return;
     }
     mat.qpow(base,ans,n-3);
     printf("%lld\n",ans.a[1][1]);
}
signed main(){
    int T=read();
    while(T--)work();
    return 0;
}
posted @ 2022-05-13 09:22  Chano_sb  阅读(156)  评论(0)    收藏  举报