矩阵乘法与矩阵加速

矩阵乘法

一个 \(n\)\(m\) 列的矩阵 \(A\) 乘以一个 \(m\)\(k\) 列的矩阵 \(B\),所得到的矩阵 \(C\) 的大小是 \(n\)\(k\) 列,计算公式如下:

\[C_{i,j} = \sum_{i = 1}^{m} A_{i, k} \times B_{k, j} \]

其实还可以这样记:\(A\) 储存着 \(n\)\(m\) 维的行向量\(B\) 储存着 \(k\)\(m\) 维的列向量\(C_{i, j}\) 就是 \(A\) 的第 \(i\) 个行向量与 \(B\) 的第 \(j\) 个列向量的数量积


矩阵加速

其实就是求一个矩阵的 \(n\) 次方时,用快速幂优化。
这时可以先将矩阵定义为一个结构体,再重载一下矩阵的乘法运算符就好了。
模板如下(矩阵大小为 \(n \times n\)):

struct matrix {
	int a[maxn][maxn];
	void init() {memset(a, 0, sizeof(a));}
	matrix operator * (const matrix& x) const {
		matrix res; res.init();
		for (int i = 1; i <= n; ++i)
			for (int j = 1; j <= n; ++j)
				for (int k = 1; k <= n; ++k)
					res.a[i][j] = (res.a[i][j] + a[i][k] * x.a[k][j] % mod) % mod;
		return res;
	}
} A;
matrix qpow(matrix x, int y) {
	matrix res; res.init();
	// res 初始化成单位矩阵, 即在【左上-右下对角线】上的位置为 1 
	// 单位矩阵乘以任何矩阵都会得到那个矩阵本身, 就跟任何数乘以 1 都会得到它自身一样 
	for (int i = 1; i <= n; ++i) res.a[i][i] = 1; 
	for (; y; y >>= 1, x = x * x)
		if (y & 1) res = res * x;
	return res;
}

矩阵加速解题思路

题目特征:

  1. 矩阵加速的题,基本都是动态规划(有的是数列递推,本质上都是需要状态转移的),只不过是用矩阵加速来加快转移。所以一道题如果是动态规划,那么可能就是需要矩阵加速来优化。
  2. 涉及状态的一个变量很大。就比如说让你计算斐波那契数列的第 \(n\)\(f(n)\),但是 \(n\) 的大小达到了 \(10^9\) 甚至更大,此时就需要矩阵加速来把它优化到 \(log(n)\)

步骤:

  1. 写出动态规划朴素的转移方程;
  2. 观察转移方程,看看是否可以设计一个矩阵加速来转移。

例子1:求斐波那契数列的第 \(n\)\(mod \ 10^9 + 7\) 的值,\(n \leq 10^9\)

先列出朴素转移方程:\(f(n) = f(n - 1) + f(n - 2)\)

现在考虑设计状态矩阵(储存的就是我们要求的值),发现 \(f(n)\) 是由 \(f(n - 1)\)\(f(n - 2)\) 转移过来的,所以状态矩阵要包含这两个变量。

\[\begin{bmatrix} f(n-1) \\ \\ f(n - 2) \\ \end{bmatrix} \]

现在考虑设计转移矩阵,两个矩阵乘出来还要是一个 \(2 \times 1\) 的状态矩阵,那么这个转移矩阵就只能是 \(2 \times 2\) 大小的。设出矩阵中的每个未知数 \(a, b, c, d\),进行求解:

\[\begin{bmatrix} a & b \\ \\ c & d \\ \end{bmatrix} \times \begin{bmatrix} f(n-1) \\ \\ f(n - 2) \\ \end{bmatrix} = \begin{bmatrix} a \times f(n-1) + b \times f(n - 2) \\ \\ c \times f(n-1) + d \times f(n - 2) \\ \end{bmatrix} = \begin{bmatrix} 1 \times f(n-1) + 1 \times f(n - 2) \\ \\ 1 \times f(n-1) + 0 \times f(n - 2) \\ \end{bmatrix} = \begin{bmatrix} f(n) \\ \\ f(n-1) \\ \end{bmatrix} \]

解得转移矩阵为:

\[\begin{bmatrix} 1 & 1 \\ \\ 1 & 0 \\ \end{bmatrix} \]

而我们又知道 \(f(1) = 1,f(2) = 1\)(也就是初始矩阵),那么我们就可以得到:

\[\begin{bmatrix} f(n) \\ \\ f(n-1) \\ \end{bmatrix}= \begin{bmatrix} 1 & 1 \\ \\ 1 & 0 \\ \end{bmatrix}^{n - 2} \times \begin{bmatrix} 1 \\ \\ 1 \\ \end{bmatrix} \]

这里之所以是 \(n - 2\) 次方是因为我们已经知道了前两项,要求第 \(n\) 项时,就是要乘 \(n - 2\) 次转移矩阵。

代码如下:

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int maxn = 5e5 + 7;
const int mod  = 1e9 + 7;

int n, mod;
struct matrix {
	ll a[3][3];
	void init() {memset(a, 0, sizeof(a));}
	matrix operator * (const matrix& x) const {
		matrix res; res.init();
		for (int i = 1; i <= 2; ++i)
			for (int j = 1; j <= 2; ++j)
				for (int k = 1; k <= 2; ++k)
					res.a[i][j] = (res.a[i][j] + a[i][k] * x.a[k][j] % mod) % mod;
		return res;
	}
};

int qpow(matrix x, int y) {
	if (y <= 0) return 1;
	matrix base; base.init();
	for (int i = 1; i <= 2; ++i) base.a[i][i] = 1;
	for (; y; y >>= 1, x = x * x)
		if (y & 1) base = base * x;
	/*
		假设转移矩阵的 n - 2 次方求出来后为:
			[a, b]
			[c, d]
		它乘以出事矩阵就是:
			[a, b] [1]
			[c, d] [1]
		乘出来得:
			[a + b]
			[c + d]
		那么 f(n) = a + b
	*/
	return (base.a[1][1] + base.a[1][2]) % mod;
}
int main() {
	scanf("%d", &n);
	matrix trans; trans.init();
	trans.a[1][1] = trans.a[1][2] = trans.a[2][1] = 1;
	trans.a[2][2] = 0;
	printf("%d", qpow(trans, n - 2));
	return 0;
}

例二:「一本通 6.5 例 4」佳佳的 Fibonacci

题目描述:\(f(i)\) 为斐波那契数列第 \(i\) 项,\(T(n) = \sum_{i=1}^{n} i \times f(i)\)。给定 \(n, m\)\(T(n) \ mod \ m\) 的值。

首先要知道一个公式:设 \(S(n)\) 为斐波那契数列的前 \(n\) 项和,那么 \(S(n) = f(n + 2) - 1\)
那么就可以推出来 \(T(n) = \sum_{i = 1}^{n} S(n) - S(i - 1)\),剩下的就是化简,最后得到 \(T(n) = n \times f(n + 2) - f(n + 3) + 2\),再用矩阵快速幂就行。

代码如下:

#include <bits/stdc++.h> 

using namespace std;

typedef long long ll;

const int maxn = 100 + 7;

/*
斐波那契数列前 n 项和:S(n) = F(n + 2) - 1 
*/

ll n, mod;
struct Matrix {
	ll a[5][5];
	void init() {memset(a, 0, sizeof(a));}
	Matrix operator * (const Matrix& x) const {
		Matrix res; res.init();
		for (int i = 1; i <= 2; ++i)
			for (int j = 1; j <= 2; ++j)
				for (int k = 1; k <= 2; ++k)
					res.a[i][j] = (res.a[i][j] + a[i][k] * x.a[k][j] % mod) % mod;
		return res;
	}
} f;
ll qpow(Matrix x, int y) {
	Matrix base; base.init();
	for (int i = 1; i <= 2; ++i) base.a[i][i] = 1;
	for (; y; y >>= 1, x = x * x)
		if (y & 1) base = base * x;
	return base.a[1][1];
}
ll F(int x) {
	f.a[1][1] = f.a[1][2] = f.a[2][1] = 1;
	f.a[2][2] = 0;
	return qpow(f, x - 1);
}
int main() {
	scanf("%lld%lld", &n, &mod);
	
	// T(n) = n * F(n + 2) - F(n + 3) + 2
	ll ans = (
			  (
	           (n % mod) * (F(n + 2) % mod) % mod - 
			   (F(n + 3) % mod) + (2 % mod)
			  ) % mod + mod
			 ) % mod;
	printf("%lld\n", ans);
	return 0;
}
/*
2 1000000007
*/

例题

洛谷 P4159 [SCOI2009] 迷路
洛谷 P3193 [HNOI2008] GT考试

posted @ 2025-02-11 10:26  syzyc  阅读(23)  评论(0)    收藏  举报