poj - 3233

poj - 3233

题目大意:

给定一个 \(n \times n\) 的矩阵 \(A\) 和一个正整数 \(k\)，求矩阵幂之和 \(sum(k) = A + A^2 + A^3 + \cdots + A^k\)，结果的每个元素对给定的模数 \(m\) 取模。

思路:

记 \(S_k = sum(k)\)，则 \(S_k = S_{k-1} A + A\)（其中 \(S_1 = A\)）。等式右边有两项，因此构造一个 \(2 \times 2\) 的分块矩阵（每一块都是 \(n \times n\) 矩阵）。

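很容易构造出 \(\begin{bmatrix}S_k & A\end{bmatrix} = \begin{bmatrix}S_{k-1} & A\end{bmatrix} \begin{bmatrix}A & 0 \\ E & E\end{bmatrix}\)，其中 \(E\) 为 \(n\) 阶单位矩阵。不断递推到 \(S_1 = A\)，得到 \(\begin{bmatrix}S_k & A\end{bmatrix} = \begin{bmatrix}A & A\end{bmatrix} \begin{bmatrix}A & 0 \\ E & E\end{bmatrix}^{k-1}\)，结果左边的 \(n \times n\) 块就是 \(S_k\)。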

AC代码:

#include<iostream>
#include<string.h>
#define ll long long
using namespace std;

// Capacity of the fixed matrix buffer. Indices run 1..2n, so N must exceed
// 2 * n (presumably n <= 30 per the problem statement -- confirm).
const int N = 110;

// Upper exponent k of the series, read from input.
ll k;
// Modulus applied to every matrix entry, read from input.
ll mod;

// Active dimension (2 * n) used by every mat operator.
int sz;
/*
 * Square matrix of long long with 1-based indexing, stored in a fixed
 * N x N buffer. Only the top-left sz x sz corner (global `sz`) takes part in
 * arithmetic, and results are reduced modulo the global `mod`.
 */
struct mat
{
    ll m[N][N];

    // Zero-initialise every entry of the buffer.
    mat() {memset(m, 0, sizeof(m));}

    // Reset every entry of the buffer to zero.
    void clear() {memset(m, 0, sizeof(m));}

    // Element-wise difference modulo mod. The extra "+ mod" undoes the
    // sign-preserving behaviour of %, so results lie in [0, mod).
    mat operator - (const mat& T) const
    {
        mat res;
        for (int i = 1; i <= sz; ++i)
        {
            for (int j = 1; j <= sz; ++j)
            {
                res.m[i][j] = ((m[i][j] - T.m[i][j]) % mod + mod) % mod;
            }
        }

        return res;
    }

    // Element-wise sum modulo mod.
    mat operator + (const mat& T) const
    {
        mat res;
        for (int i = 1; i <= sz; ++i)
        {
            for (int j = 1; j <= sz; ++j)
            {
                res.m[i][j] = (m[i][j] + T.m[i][j]) % mod;
            }
        }

        return res;
    }

    // Matrix product modulo mod, in i-k-j loop order.
    // Entries of the left operand (*this) are normalised into [0, mod) first,
    // so raw or negative input on this side cannot overflow the product.
    // The right operand T is assumed to hold values in [0, mod), which is
    // true for every result of *, - and ^. Under that assumption each partial
    // sum stays below mod + mod * mod, which fits in a long long for
    // mod up to about 3e9.
    mat operator * (const mat& T) const
    {
        mat res;
        for (int i = 1; i <= sz; ++i)
        {
            for (int k = 1; k <= sz; ++k)
            {
                ll r = (m[i][k] % mod + mod) % mod;
                if (r == 0) continue;  // a zero factor adds nothing to row i
                for (int j = 1; j <= sz; ++j)
                {
                    res.m[i][j] = (res.m[i][j] + r * T.m[k][j]) % mod;
                }
            }
        }
        return res;
    }

    // Returns (*this)^x modulo mod by square-and-multiply, using
    // O(log x) matrix products. A non-positive x yields the identity matrix.
    // The loop tests x > 0 (not x != 0) because right-shifting a negative x
    // does not reach 0 on common compilers and would never terminate.
    mat operator ^(ll x) const
    {
        mat res, bas;
        // Identity; "1 % mod" keeps the entry reduced even when mod == 1.
        for (int i = 1; i <= sz; ++i) res.m[i][i] = 1 % mod;
        for (int i = 1; i <= sz; ++i)
        {
            for (int j = 1; j <= sz; ++j) bas.m[i][j] = (m[i][j] % mod + mod) % mod;
        }
        while (x > 0)
        {
            if (x & 1) res = res * bas;
            x >>= 1;
            if (x > 0) bas = bas * bas;  // skip the final, unused squaring
        }
        return res;
    }
};

// m1: the block row [A  A] in rows 1..n; it is overwritten with the result.
mat m1;
// m2: the 2n x 2n transition matrix [[A, 0], [E, E]].
mat m2;
// Order of the input matrix A.
int n;
int main()
{
    cin >> n >> k >> mod;
    sz = 2 * n;
    for (int i = 1; i <= n; i++)
    {
        for (int j = 1; j <= n; j++) 
        {
            cin >> m1.m[i][j];
            m1.m[i][j + n] = m1.m[i][j];
            m2.m[i][j] = m1.m[i][j];
        }
    }
    for (int i = 1; i <= n; i++) m2.m[i + n][i] = m2.m[i + n][i + n] = 1;
    m1 = m1 * (m2 ^ (k - 1));
    for (int i = 1; i <= n; i++)
    {
        for (int j = 1; j <= n; j++)
        {
            cout << m1.m[i][j] % mod <<' ';
        }
        cout <<"\n";
    }
}
posted @ 2023-01-04 18:03  llinzy  阅读(28)  评论(0)    收藏  举报