Loading

【SHTSC2013】超级跳马

题目链接

题目大意

\(n*m\)的网格上,一只马在点\((1,1)\),点\((i,j)\)可以跳到\((i-1,j+k)\)\((i,j+k)\)\((i+1,j+k)\),其中\(k\)是一个奇数,求跳到\((n,m)\)的方案数。

解析

设:
\(f_{i,j}\)表示跳到\((j,i)\)的方案数(为了方便我换了一下\(i,j\)的顺序,相当于按列为阶段转移)
\(a_{i,j}=\sum_{k=0}f_{i,j-2k}\)
\(b_{i,j}=\sum_{k=0}f_{i,j-2k-1}\)

得到三者之间关系是:

\(f_{i,j}=a_{i-1,j}+a_{i-1,j-1}+a_{i-1,j+1}\)
\(a_{i,j}=b_{i-1,j}+f_{i,j}\)
\(b_{i,j}=a_{i-1,j}\)

仅根据这三条式子,就能够用\(O(nm)\)的时间复杂度求出\(f_{i,j}\)了。

\(m\leq 10^9\),还需优化。

原来是三个状态的转移,现在我们变一下式子:
\(a_{i,j}=a_{i-2,j}+a_{i-1,j}+a_{i-1,j-1}+a_{i-1,j+1}\)

现在只剩\(a\)一个状态的转移了,最后求\(f_{i,j}=a_{i,j}-a_{i-2,j}\)即可。

由于\(m\)很大,考虑使用矩阵乘法。

转移矩阵的构造方法很巧妙,我们把\(a_{i-1,1 \sim n}\)还有\(a_{i-2,1 \sim n}\)放在初始矩阵的第一行,其它位置全部填\(0\)

例如\(n=3\)时,初始矩阵为:
\(\begin{Bmatrix} a_{i-1,1} & a_{i-1,2} & a_{i-1,3} & a_{i-2,1} & a_{i-2,2} & a_{i-2,3} \\ 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 \end{Bmatrix} \quad\)

转移矩阵为:
\(\begin{Bmatrix} 1 & 1 & 0 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 & 1 & 0 \\ 0 & 1 & 1 & 0 & 0 & 1 \\ 1 & 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 & 0 & 0 \end{Bmatrix} \quad\)

这样时间复杂度降为\(O(n^3logm)\),问题解决了。

Code

#include <cstdio>
#include <cstring>

const int N = 57, P = 30011;
int max(int a, int b) { return a > b ? a : b; }
int min(int a, int b) { return a < b ? a : b; }

int n, m;

struct matrix
{
	int num[N * 2][N * 2];
	matrix operator*(matrix a)
	{
		matrix c; memset(c.num, 0, sizeof(c.num));
		for (int i = 0; i < 2 * n; i++)
			for (int j = 0; j < 2 * n; j++)
				for (int k = 0; k < 2 * n; k++)
					c.num[i][j] = (c.num[i][j] + num[i][k] * a.num[k][j] % P) % P;
		return c;
	}
} bas, mov, ret;

int getit(int m, int n)
{
	if (m <= 0) return 0;
	memset(bas.num, 0, sizeof(bas.num));
	memset(mov.num, 0, sizeof(mov.num));
	memset(ret.num, 0, sizeof(ret.num));
	for (int j = 0; j < n; j++) for (int i = max(j - 1, 0); i <= min(j + 1, n - 1); i++) mov.num[i][j] = 1;
	for (int i = n; i < 2 * n; i++) mov.num[i][i - n] = 1;
	for (int j = n; j < 2 * n; j++) mov.num[j - n][j] = 1;
	bas.num[0][0] = 1;
	for (int i = 0; i < 2 * n; i++) ret.num[i][i] = 1;
	m--;
	while (m)
	{
		if (m & 1) ret = ret * mov;
		mov = mov * mov, m >>= 1;
	}
	bas = bas * ret;
	return bas.num[0][n - 1];
}

int main()
{
	scanf("%d%d", &n, &m);
	printf("%d\n", (getit(m, n) - getit(m - 2, n) + P) % P);
	return 0;
}
posted @ 2019-07-02 15:58  gz-gary  阅读(207)  评论(0编辑  收藏  举报