[loj6388] 「THUPC2018」赛艇 / Citing

Description

​ 给你一个\(~n \times m~\)\(~01~\)矩阵,一个人在这个矩阵中走了\(~k~\)步,每一次都往四联通方向中的一个走一步。给定这个人每一步走的方向,已知这个人经过的每一步都没有经过原矩阵中\(~1~\)的位置。问合法的起点有多少种?保证至少有一组解。\(~1 \leq n, m \leq 1500, ~k \leq 5 \times 10 ^ 6~\).

Solution

​ 不难发现那条路径通过补全\(~0~\)之后其实就是一个\(~01~\)矩阵,其中的\(~1~\)就是原路径。问题变成了把该矩阵放在原矩阵中(严格内含)不产生冲突的方案数,实质上就是起来全是\(~0~\)的方案数。考虑怎么快速求这个问题。把该矩阵通过补\(~0~\)变成和原矩阵一样大的规模,把两个矩阵都拉成长度为\(~n \times m~\)的序列,倒序一个序列做\(~FFT~\)\(~NTT~\)在看对应位置上是否为\(~0~\)统计答案即可。至于这样为什么是对的,可以考虑这个对应位置的数代表的东西到底是什么,卷积中\(~ans_i~\)代表下标和为\(~i~\)的各项乘积之和,由于之前做过一个区间反转,所以这个\(~ans_i~\)就代表路径矩阵在原矩阵中起始位置为\(~i~\)时矩阵各项匹配起来的乘积的和,而在只有\(~0, 1~\)的情况下,乘法和或的运算法则一样。所以当\(~ans_i~\)\(~0~\)时,就代表这个匹配位置是合法的,因为没有任何一个\(~1~\)同位。

Code

#include<bits/stdc++.h>
#define For(i, j, k) for(int i = j; i <= k; ++i)
#define Forr(i, j, k) for(int i = j; i >= k; --i)
using namespace std;

inline int read() {
	int x = 0, p = 1; char c = getchar();
	for(; !isdigit(c); c = getchar()) if(c == '-') p = -1;
	for(; isdigit(c); c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
	return x *= p;
}

inline void File() {
#ifndef ONLINE_JUDGE
	freopen("loj6388.in", "r", stdin);
	freopen("loj6388.out", "w", stdout);
#endif
}

const int N = 1500 + 10, M = (N * N) << 2, mod = 998244353;
int a[M], b[M], rev[M], powg[M], invg[M], k;
int n, m, cnt1, cnt2, siz, len, bit, c[N << 1][N << 1];
char ss[M]; 

inline int qpow(int a, int b) {
	static int res;
	for (res = 1; b; a = 1ll * a * a % mod, b >>= 1)
		if (b & 1) res = 1ll * res * a % mod;
	return res;
}

inline void NTT(int *a, int flag) {
	For(i, 0, siz - 1) if (rev[i] > i) swap(a[rev[i]], a[i]);
	for (int i = 2; i <= siz; i <<= 1) {
		int wn = flag ? powg[i] : invg[i];
		for (int j = 0; j < siz; j += i) {
			int w = 1;
			for (int k = 0; k < (i >> 1); ++ k, w = 1ll * w * wn % mod) {
				int x = a[j + k], y = 1ll * w * a[j + k + (i >> 1)] % mod;
				a[j + k] = (x + y) % mod, a[j + k + (i >> 1)] = (x - y + mod) % mod;
			}
		}
	}
	if (!flag) {
		int g = qpow(siz, mod - 2);
		For(i, 0, siz) a[i] = 1ll * a[i] * g % mod;
	}
}

int main() {
	File();
	n = read(), m = read(), k = read();

	For(i, 1, n) {
		scanf("%s", ss + 1);
		For(j, 1, m) a[(i - 1) * m + j - 1] = ss[j] - 48;
	}
	cnt1 = n * m - 1;

	int x2 = n, y2 = m, x0 = n, y0 = m, lx = n, ly = m;
	scanf("%s", ss + 1), c[lx][ly] = 1; 
	For(i, 1, k) {
		if (ss[i] == 'w') c[-- lx][ly] = 1;
		if (ss[i] == 'a') c[lx][-- ly] = 1;
		if (ss[i] == 's') c[++ lx][ly] = 1;
		if (ss[i] == 'd') c[lx][++ ly] = 1;
		x0 = min(x0, lx), y0 = min(y0, ly);
		x2 = max(x2, lx), y2 = max(y2, ly);
	}

	For(i, x0, x0 + n - 1) For(j, y0, y0 + m - 1) b[cnt1 - (cnt2 ++)] = c[i][j];
	-- cnt2;

	len = cnt1 + cnt2;
	for (siz = 1; siz <= len; siz <<= 1) ++ bit;
	For(i, 0, siz - 1) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)); 

	int g = qpow(3, mod - 2);
	for (int i = 1; i <= siz; i <<= 1) {
		invg[i] = qpow(g, (mod - 1) / i);
		powg[i] = qpow(3, (mod - 1) / i);
	}

	NTT(a, 1), NTT(b, 1);
	For(i, 0, siz - 1) a[i] = 1ll * a[i] * b[i] % mod;
	NTT(a, 0);
	
	int ans = 0;
	For(i, 1, n - (x2 - x0)) For(j, 1, m - (y2 - y0)) 
		if (a[cnt1 + (i - 1) * m + j - 1] == 0) ++ ans; 

	cout << ans << endl;
	return 0;
}
posted @ 2018-08-22 19:11  LSTete  阅读(355)  评论(0编辑  收藏  举报