NFLSOJ829 【2020六校联考WC #1】LJJ的生日礼物

NFLSOJ829 【2020六校联考WC #1】LJJ的生日礼物

题目大意

题目链接

一个长度为 \(L\) 的序列。有 \(K\) 种颜色。你需要给每个位置染上一种颜色,使得没有距离 \(\leq 2\) 的两个位置颜色相同。有 \(N\) 个位置 \(p_{1\dots N}\),它们的颜色已经确定了,分别是 \(c_{1\dots N}\)。请求出给剩下位置染色的方案数。答案对 \(10^9 + 7\) 取模。

数据范围:\(0\leq N\leq 1000\)\(1\leq K\leq 10^9\)\(\max(1, N)\leq L\leq 10^9\)。保证 \(p_1 < p_2\dots < p_N\)

本题题解

考虑朴素的 DP。设 \(\mathrm{dp}_1(i, j, k)\) 表示考虑了序列的前 \(i\) 个位置(\(2\leq i\leq L\)),第 \(i - 1\) 个位置颜色为 \(j\),第 \(i\) 个位置颜色为 \(k\) 的染色方案数。时间复杂度 \(\mathcal{O}(LK^3)\)

\(c_{1\dots N}\) 中总共出现了 \(W\) 种颜色(\(W\leq N\))。可以把剩下所有颜色,视为一种特殊颜色。具体来说,在 DP 状态里,如果 \(j\leq W\),它表示的是某一种出现过的颜色,正常转移即可;否则 \(j = W + 1\),即特殊颜色,这个状态表示:【$j = $ 剩下任意一种颜色时,的 DP 值】之和(这些颜色的 DP 值显然是完全一样的,它们的转移也是完全一样的),于是转移时乘以系数 \(K - W\) 即可。时间复杂度 \(\mathcal{O}(LW^3) = \mathcal{O}(LN^3)\)

继续优化,我们要让时间复杂度摆脱 \(L\),就必须抛弃上述的、逐个位置 DP 的想法,直接在 \(N\) 个关键点之间进行转移。于是有了一个新的状态设计:设 \(\mathrm{dp}_2(i, j)\) 表示考虑了前 \(i\) 个关键点,\(p_i + 1\) 位置上颜色为 \(j\) 的方案数。转移时,枚举 \(p_{i + 1} + 1\) 的颜色 \(j'\)。要从 \(\mathrm{dp}_2(i, j)\) 转移到 \(\mathrm{dp}_2(i + 1, j')\)。我们想快速求出转移系数。问题可以转化为:有一段长度为 \(p_{i + 1} - p_{i} + 2\) 的序列,开头两个位置颜色分别为 \(c_i, j\),末尾两个位置颜色分别为 \(c_{i + 1}, j'\),求给中间其他位置染色的方案数

发现这个问题的答案,只和 \(c_{i}, j, c_{i + 1}, j'\) 这四个颜色两两是否相等有关,与它们具体是什么无关。换句话说,等价的情况只有 \(\mathrm{Bell}(4) = 15\) 种(实际更少,因为例如 \(c_i = j\) 等情况是不合法的)。并且因为长度为 \(p_{i + 1} - p_{i} + 2\),共 \(N - 1\) 种,所以只需要对这 \((N - 1)\cdot\mathrm{Bell}(4)\) 个问题分别预处理答案即可。

如何预处理答案?朴素的想法还是做第一个 DP。即 \(\text{dp}_1(i, j, k)\) 表示前 \(i\) 个位置,最后两个位置颜色分别为 \(j,k\) 的方案数。关键的颜色在这里只有 \(4\) 种(\(c_i, j, c_{i + 1}, j'\)),其他颜色可以合并为一种特殊颜色,故状态数是 \(5\times 5 = 25\) 的。进一步,甚至 \(c_i, j\) 这两种颜色也不重要,我们只把 \(c_{i + 1}, j'\) 视为关键颜色,就可以完成 DP。状态数优化为 \(3\times 3 = 9\) 种。减去连续两个颜色相等,这样不合法的状态后,只剩 \(7\) 种。朴素 DP 时间复杂度 \(\mathcal{O}(L \cdot 7)\)。因为状态数 \(7\) 很小,而长度 \(L\) 很大,考虑用矩阵快速幂,可以优化为 \(\mathcal{O}(7^3\log L)\)

于是,可以在 \(\mathcal{O}(N\cdot \mathrm{Bell}(4)\cdot 7^3\cdot \log L)\) 的时间复杂度内完成所有预处理。接下来用预处理的信息,可以 \(\mathcal{O}(1)\) 回答从 \(\mathrm{dp}_2(i, j)\)\(\mathrm{dp}_2(i + 1, j')\) 转移的系数。现在,这个 DP 的时间复杂度是 \(\mathcal{O}(NK^2)\)

发现对所有 \(j'\neq j, j'\neq c_i\)\(\mathrm{dp}_2(i,j)\) 对它们的贡献是一样的。于是就没必要用两层循环,先枚举 \(j\) 再枚举 \(j'\) 了。可以先求出系数和,再加到所有 \(j'\) 上。时间复杂度 \(\mathcal{O}(NK)\)

最后,把除了 \(c_{1\dots N}\) 外,其他颜色视为一种特殊颜色。可以优化为 \(\mathcal{O}(N^2)\)

总时间复杂度 \(\mathcal{O}(N\cdot \mathrm{Bell}(4)\cdot 7^3\cdot \log L + N^2)\)

参考代码

// problem: NFLSOJ829
#include <bits/stdc++.h>
using namespace std;

#define mk make_pair
#define fi first
#define se second
#define SZ(x) ((int)(x).size())

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;

template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }

const int MAXN = 1000;
const int MOD = 1e9 + 7;

inline int mod1(int x) { return x < MOD ? x : x - MOD; }
inline int mod2(int x) { return x < 0 ? x + MOD : x; }
inline void add(int &x, int y) { x = mod1(x + y); }
inline void sub(int &x, int y) { x = mod2(x - y); }
inline int pow_mod(int x, int i) {
	int y = 1;
	while (i) {
		if (i & 1) y = (ll)y * x % MOD;
		x = (ll)x * x % MOD;
		i >>= 1;
	}
	return y;
}

int n, K, L;
int p[MAXN + 5], c[MAXN + 5];
int cols[MAXN + 5], cnt_col;

bool special_cases() {
	for (int i = 1; i < n; ++i) {
		if (p[i + 1] - p[i] <= 2 && c[i] == c[i + 1]) {
			cout << 0 << endl;
			return true;
		}
	}
	
	if (K == 0) {
		cout << 0 << endl;
		return true;
	}
	if (K == 1) {
		cout << (L == 1) << endl;
		return true;
	}
	if (K == 2) {
		if (L == 1) {
			cout << (n == 1 ? 1 : 2) << endl;
		} else if (L == 2) {
			if (n == 0) {
				cout << 2 << endl;
			} else if (n == 1) {
				cout << 1 << endl;
			} else {
				cout << (c[1] != c[2]) << endl;
			}
		} else { // L > 2
			cout << 0 << endl;
		}
		return true;
	}
	if (n == 0) {
		if (L == 1) {
			cout << K << endl;
			return true;
		}
		cout << (ll)K * (K - 1) % MOD * pow_mod(K - 2, L - 2) % MOD << endl;
		return true;
	}
	if (p[1] == L) {
		if (L == 1) {
			cout << 1 << endl;
			return true;
		}
		cout << (ll)(K - 1) * pow_mod(K - 2, L - 2) % MOD << endl;
		return true;
	}
	
	cerr << "* no special case" << endl;
	return false;
}

int dp[MAXN + 5][MAXN + 5]; // 第 i 个固定点, a[p[i] + 1] = j

/*
int g[3][3][55][3][3];
void brute_force_dp(int g[55][3][3], int x, int y) {
	g[2][x][y] = 1;
	for (int i = 3; i <= 50; ++i) {
		add(g[i][1][2], (ll)g[i - 1][0][1] * (K - 2) % MOD);
		add(g[i][0][2], (ll)g[i - 1][1][0] * (K - 2) % MOD);
		
		add(g[i][2][1], g[i - 1][0][2]);
		if (K > 3)
		add(g[i][2][2], (ll)g[i - 1][0][2] * (K - 3) % MOD);
		add(g[i][0][1], g[i - 1][2][0]);
		if (K > 3)
		add(g[i][0][2], (ll)g[i - 1][2][0] * (K - 3) % MOD);
		
		add(g[i][2][0], g[i - 1][1][2]);
		if (K > 3)
		add(g[i][2][2], (ll)g[i - 1][1][2] * (K - 3) % MOD);
		add(g[i][1][0], g[i - 1][2][1]);
		if (K > 3)
		add(g[i][1][2], (ll)g[i - 1][2][1] * (K - 3) % MOD);
		
		add(g[i][2][1], g[i - 1][2][2]);
		add(g[i][2][0], g[i - 1][2][2]);
		if (K > 4)
		add(g[i][2][2], (ll)g[i - 1][2][2] * (K - 4) % MOD);
	}
}
*/ // 把以上的暴力 DP 换成矩阵快速幂!
/*
g[x][y] 表示最末尾的两种颜色是 x, y
0, 1 是 p[i + 1], p[i + 1] + 1 的颜色, 2 代表其他颜色

状态共有 7 种:

g[0][1] -> mat[0]
g[1][0] -> mat[1]
g[0][2] -> mat[2]
g[2][0] -> mat[3]
g[1][2] -> mat[4]
g[2][1] -> mat[5]
g[2][2] -> mat[6]
*/
struct Matrix {
	int a[7][7];
	void identity() {
		for (int i = 0; i < 7; ++i)
			for (int j = 0; j < 7; ++j)
				a[i][j] = (i == j);
	}
	void clear() {
		for (int i = 0; i < 7; ++i)
			for (int j = 0; j < 7; ++j)
				a[i][j] = 0;
	}
	Matrix() {
		clear();
	}
};
Matrix operator * (const Matrix& X, const Matrix& Y) {
	Matrix Z;
	for (int i = 0; i < 7; ++i) {
		for (int j = 0; j < 7; ++j) {
			for (int k = 0; k < 7; ++k) {
				Z.a[i][j] = ((ll)Z.a[i][j] + (ll)X.a[i][k] * Y.a[k][j]) % MOD;
			}
		}
	}
	return Z;
}
Matrix mat_pow(Matrix X, int i) {
	Matrix Y;
	Y.identity();
	while (i) {
		if (i & 1) Y = Y * X;
		X = X * X;
		i >>= 1;
	}
	return Y;
}

Matrix Trans;
int mat_idx[3][3];
int f[MAXN + 5][3][3][3][3];
map<int, int> id;
int cnt_id;

int id_len[MAXN + 5];

int makef(int len) {
	if (id.count(len)) return id[len];
	id[len] = ++cnt_id;
	assert(len >= 2);
	
	Matrix A = mat_pow(Trans, len - 2);
	for (int i = 0; i <= 2; ++i) {
		for (int j = 0; j <= 2; ++j) {
			if ((i == 0 && j == 0) || (i == 1 && j == 1))
				continue;
			Matrix B;
			B.a[0][mat_idx[i][j]] = 1;
			B = B * A;
			for (int x = 0; x <= 2; ++x) {
				for (int y = 0; y <= 2; ++y) {
					if ((x == 0 && y == 0) || (x == 1 && y == 1))
						continue;
					f[cnt_id][i][j][x][y] = B.a[0][mat_idx[x][y]];
				}
			}
		}
	}
	return cnt_id;
}
void init() {
//	memset(g, 0, sizeof(g));
//	for (int i = 0; i <= 2; ++i) for (int j = 0; j <= 2; ++j) brute_force_dp(g[i][j], i, j);
	
	// 构造转移矩阵:
	Trans.clear();
	Trans.a[0][4] = K - 2;            // g[0][1] -> g[1][2]
	Trans.a[1][2] = K - 2;            // g[1][0] -> g[0][2]
	Trans.a[2][5] = 1;                // g[0][2] -> g[2][1]
	if (K > 3) Trans.a[2][6] = K - 3; // g[0][2] -> g[2][2]
	Trans.a[3][0] = 1;                // g[2][0] -> g[0][1]
	if (K > 3) Trans.a[3][2] = K - 3; // g[2][0] -> g[0][2]
	Trans.a[4][3] = 1;                // g[1][2] -> g[2][0]
	if (K > 3) Trans.a[4][6] = K - 3; // g[1][2] -> g[2][2]
	Trans.a[5][1] = 1;                // g[2][1] -> g[1][0]
	if (K > 3) Trans.a[5][4] = K - 3; // g[2][1] -> g[1][2]
	Trans.a[6][5] = 1;                // g[2][2] -> g[2][1]
	Trans.a[6][3] = 1;                // g[2][2] -> g[2][0]
	if (K > 4) Trans.a[6][6] = K - 4; // g[2][2] -> g[2][2]
	
	mat_idx[0][1] = 0;
	mat_idx[1][0] = 1;
	mat_idx[0][2] = 2;
	mat_idx[2][0] = 3;
	mat_idx[1][2] = 4;
	mat_idx[2][1] = 5;
	mat_idx[2][2] = 6;
	
	cnt_id = 0;
	id.clear();
	for (int i = 1; i < n; ++i) {
		id_len[i] = makef(p[i + 1] - p[i] + 2);
	}
	if (p[n] == L) {
		makef(p[n] - p[n - 1] + 1);
	}
}
int calc(int id, int s1, int s2, int t1, int t2) {
	int c1 = (s1 == t1 ? 0 : (s1 == t2 ? 1 : 2));
	int c2 = (s2 == t1 ? 0 : (s2 == t2 ? 1 : 2));
	return f[id][c1][c2][0][1];
}
void solve_case() {
	cin >> n >> K >> L;
	for (int i = 1; i <= n; ++i) {
		cin >> p[i] >> c[i];
		++p[i];
	}
	
	if (special_cases())
		return;
	
	cnt_col = 0;
	for (int i = 1; i <= n; ++i) {
		cols[++cnt_col] = c[i];
	}
	sort(cols + 1, cols + cnt_col + 1);
	cnt_col = unique(cols + 1, cols + cnt_col + 1) - (cols + 1);
	for (int i = 1; i <= n; ++i) {
		c[i] = lower_bound(cols + 1, cols + cnt_col + 1, c[i]) - cols;
	}
	assert(cnt_col <= K);
	
	init();
	
	for (int i = 1; i <= n; ++i)
		for (int j = 1; j <= cnt_col + 1; ++j)
			dp[i][j] = 0;
	
	int w = pow_mod(K - 2, p[1] - 1);
	for (int j = 1; j <= cnt_col; ++j) {
		if (j == c[1])
			continue;
		dp[1][j] = w;
	}
	dp[1][cnt_col + 1] = (ll)w * (K - cnt_col) % MOD;
	
	for (int i = 1; i < n; ++i) {
		// dp[i][j1] -> dp[i + 1][j2]
		if (p[i] + 1 == p[i + 1]) {
			if (p[i + 1] == L) {
				cerr << "* last two places fixed" << endl;
				cout << dp[i][c[i + 1]] << endl;
				return;
			}
			for (int j2 = 1; j2 <= cnt_col; ++j2) {
				// a[p[i + 1] + 1] = j2
				if (j2 == c[i + 1] || j2 == c[i])
					continue;
				dp[i + 1][j2] = dp[i][c[i + 1]];
			}
			dp[i + 1][cnt_col + 1] = (ll)dp[i][c[i + 1]] * (K - cnt_col) % MOD;
			continue;
		}
		
		if (p[i + 1] == L) {
			cerr << "* last place fixed" << endl;
			
			int ans = 0;
			for (int j1 = 1; j1 <= cnt_col + 1; ++j1) {
				if (j1 == c[i])
					continue;
				if (p[i + 1] - p[i] <= 3 && j1 == c[i + 1])
					continue;
				
				int len = p[i + 1] - p[i] + 1;
				int l = id[len];
				
				if (c[i] == c[i + 1]) {
					add(ans, (ll)dp[i][j1] * (f[l][0][1][1][0] + f[l][0][1][2][0]) % MOD);
				} else {
					if (j1 == c[i + 1]) {
						add(ans, (ll)dp[i][j1] * (f[l][0][1][0][1] + f[l][0][1][2][1]) % MOD);
					} else {
						add(ans, (ll)dp[i][j1] * (f[l][0][2][0][1] + f[l][0][2][2][1]) % MOD);
					}
				}
			}
			cout << ans << endl;
			return;
		}
		
		
		int val = 0;
		for (int j1 = 1; j1 <= cnt_col + 1; ++j1) {
			if (j1 == c[i])
				continue;
			if (p[i + 1] - p[i] <= 3 && j1 == c[i + 1])
				continue;
			
			// j2 的特殊点: j1, c[i], cnt_col + 1
			
			
			int j2 = 1;
			while (j2 <= cnt_col && (j2 == j1 || j2 == c[i] || j2 == c[i + 1])) j2++; // 找到一般的 j2
			if (j2 <= cnt_col) {
				int _val = (ll)dp[i][j1] * calc(id_len[i], c[i], j1, c[i + 1], j2) % MOD;
				add(val, _val);
				if (j1 <= cnt_col && j1 != c[i + 1])
					sub(dp[i + 1][j1], _val);
				/*
				for (int j2 = 1; j2 <= cnt_col; ++j2) {
					if (j2 == j1 || j2 == c[i] || j2 == c[i + 1])
						continue;
					add(dp[i + 1][j2], _val);
				}
				*/
			}
			
			if (j1 <= cnt_col && p[i] + 2 < p[i + 1] && j1 != c[i + 1]) {
				// j2 == j1
				// j2 必 <= cnt_col
				add(dp[i + 1][j1], (ll)dp[i][j1] * calc(id_len[i], c[i], j1, c[i + 1], j1) % MOD);
			}
			
			if (c[i] != c[i + 1]) {
				// j2 == c[i]
				// j2 必 != j1
				// j2 必 <= cnt_col
				add(dp[i + 1][c[i]], (ll)dp[i][j1] * calc(id_len[i], c[i], j1, c[i + 1], c[i]) % MOD);
			}
			
			if (cnt_col < K) {
				// j2 == cnt_col + 1
				
				if (j1 == cnt_col + 1) {
					// 1. j2 实际等于 j1
					add(dp[i + 1][cnt_col + 1], (ll)dp[i][j1] * calc(id_len[i], c[i], j1, c[i + 1], j1) % MOD);
					// 2. j2 实际不等于 j1
					if (K - cnt_col >= 2)
						add(dp[i + 1][cnt_col + 1], (ll)dp[i][j1] * calc(id_len[i], c[i], j1, c[i + 1], j1 + 1) % MOD * (K - cnt_col - 1) % MOD);
				} else {
					add(dp[i + 1][cnt_col + 1], (ll)dp[i][j1] * calc(id_len[i], c[i], j1, c[i + 1], cnt_col + 1) % MOD * (K - cnt_col) % MOD);
				}
			}
			
			/*
			for (int j2 = 1; j2 <= cnt_col + 1; ++j2) {
				if (j2 == c[i + 1])
					continue;
				if (p[i] + 2 == p[i + 1] && j1 <= cnt_col && j2 <= cnt_col && j1 == j2)
					continue;
				
				if (j1 == cnt_col + 1 && j2 == cnt_col + 1) {
					// 1. j2 实际等于 j1
					
					if (K - cnt_col >= 1)
						add(dp[i + 1][j2], (ll)dp[i][j1] * calc(p[i + 1] - p[i] + 2, c[i], j1, c[i + 1], j2) % MOD);
					
					// 2. j2 实际不等于 j1
					
					if (K - cnt_col >= 2)
						add(dp[i + 1][j2], (ll)dp[i][j1] * calc(p[i + 1] - p[i] + 2, c[i], j1, c[i + 1], j2 + 1) % MOD * (K - cnt_col - 1) % MOD);
					
					continue;
				}
				
				
				int w = (j2 == cnt_col + 1 ? K - cnt_col : 1);
				add(dp[i + 1][j2], (ll)dp[i][j1] * calc(p[i + 1] - p[i] + 2, c[i], j1, c[i + 1], j2) % MOD * w % MOD);
			}
			*/
		}
		for (int j2 = 1; j2 <= cnt_col; ++j2) {
			if (j2 == c[i] || j2 == c[i + 1])
				continue;
			add(dp[i + 1][j2], val);
		}
	}
	
	int ans = 0;
	int tmp = pow_mod(K - 2, L - p[n] - 1);
	for (int j = 1; j <= cnt_col + 1; ++j) {
		add(ans, (ll)dp[n][j] * tmp % MOD);
	}
	cout << ans << endl;
}
int main() {
	int T; cin >> T;
	for (int t = 1; t <= T; ++t) {
		cout << "Case #" << t << ": ";
		cerr << endl;
		solve_case();
	}
	return 0;
}
posted @ 2021-01-20 11:39  duyiblue  阅读(96)  评论(1编辑  收藏  举报