第 50 届 ICPC 国际大学生程序设计竞赛邀请赛武汉站 G 题题解

题目大意

给定一个 \(n \times m\) 的网格,要从 \((1, 1)\) 走到 \((n, m)\) 且每次只能向下或向右走,每个格子 \((i, j)\) 有个元素值 \(a_{i,j}\)。定义路径的价值为:路径上不同元素的个数,求所有路径的价值之和。

思路

枚举所有路径并记录路径上不同元素个数的方法明显不可以且极难编码,我们可以转换思路:求对于每种元素,有多少个不同路径方案会经过它,也即是每种元素会提供多少贡献。

对于某种元素的贡献可以想到两种方式解决:

  • 直接计算该元素的每个点提供的贡献;
  • 转换思路:某种元素的贡献 = 所有贡献和 - 所有不经过该元素的贡献和。

对于第一种方式:

当某元素只有一个点 \((i, j)\) 时,贡献为:从 \((1, 1)\)\((i, j)\) 的总方案 乘 从 \((i, j)\)\((n, m)\)的总方案。

而从 \((1, 1)\)\((i, j)\)的总方案,就是 \(C_{i + j - 2}^{i - 1}\)

为什么从 $(1, 1)$ 到 $(i, j)$的总方案,就是 $C_{i + j - 2}^{i - 1}$ ? 要从 $(1, 1)$ 到 $(i, j)$ 一共要进行 $i + j - 2$ 步,其中 $i - 1$ 步选择向下走,$j - 1$ 步选择向右走,在 $i + j - 2$ 步中确定 $i - 1$ 步向下,则剩下就是向右,所以总方案为 $C_{i + j - 2}^{i - 1}$。

当该元素有两个点 \((i, j)\)\((x, y)\) 时,假设 \((i, j)\)\((x, y)\) 之前,那么计算 \((x, y)\) 的贡献为:从 \((1, 1)\)\((i, j)\)的总方案 减去 \((i, j)\)\((x, y)\) 的方案数的差 乘 从 \((x, y)\)\((n, m)\)的总方案,这就是容斥原理。同理,当在点 \((x, y)\) 之前不止一个点时,要减去所有的重复的路径。

总结:
我们定义 \(f(i)\) 为:从 \((1, 1)\) 到第 \(i\) 个点 \((x_i, y_i)\) 且不经过同种元素的方案数,那么有公式:

\[f(i) = C_{x_i + y_i - 2}^{x_i - 1} - \sum_{j = 1}^{i - 1}(f(j) \times C_{x_i + y_i - x_j - y_j}^{x_i - x_j}) \]

该元素的贡献的公式为:

\[\sum_{i = 1}^{k}(f(i) \times C_{n + m - x_i - y _ i}^{n - x_i}) \]

这样,我们就可以枚举所有不同元素,计算出它们的贡献并求和,假设第 \(i\) 种元素的数量为 \(k\) ,那么时间复杂度就是 \(O(k^2)\),假设一共有
\(s\) 种不同元素,总时间复杂度就是 \(O(\sum_{i = 1}^{s}k_i^2)\)


对于第二种方式:

我们可以使用 \(dp\) 来解决所有不经过某个元素 \(x\) 的贡献,定义 \(dp\) 数组为 \(dp[i][j]\):表示从起点到 \((i, j)\) 且不经过元素 \(x\) 的方案数。

那么有转移方程:

\[dp[i][j] = \begin{cases} dp[i - 1][j] + dp[i][j - 1], \, \, a[i][j] \neq x \\ 0, \, \, \, \, \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ a[i][j] = x \end{cases} \]

计算完后,\(dp[n][m]\) 就是所有不经过某个元素 \(x\) 的贡献,所有贡献和可以使用第一种方式的组合数快速求出: \(C_{n + m - 2}^{n - 1}\)

所以,某个元素 \(x\) 的贡献 = \(C_{n + m - 2}^{n - 1} - dp[n][m]\)

然后,枚举所有不同元素求出它们的贡献并求和。

动态规划的时间复杂度为 \(O(nm)\),假设一共有 \(s\) 种不同元素,那么总时间复杂度为:\(O(s \times nm)\)


当我们分析极端情况时,两种方式的时间复杂度都会变为 \(O((nm)^2)\),一定会超时,这时就要讲本题的核心内容了:根号分治。

根号分治是对暴力算法的优化,设置基准值 \(B\) 为某个数据的开根,小于等于 \(B\) 的用一种方式,大于 \(B\) 的用另一种方式,从而使整体时间复杂度达到优化。

观察发现,第一种容斥原理,当不同元素少,同个元素数量多时,时间复杂度为恶化为 \(O((nm)^2)\),而第二种动态规划,当不同元素多,同个元素数量少时,时间复杂度为恶化为 \(O((nm)^2)\)

那我们就可以取长补短,同个元素数量少时用容斥原理,同个元素数量多时用动态规划,这个基准值可以取 \(\sqrt{nm}\),最终时间复杂度会稳定在 \(O(nm\sqrt{nm})\)

证明:

假设有 \(s\) 个元素满足数量 \(k \le \sqrt{nm}\),那么有

\[\sum_{i}^{s}{k_i} \le nm \]

每个元素做容斥原理时间复杂度为 \(O(k^2)\),总时间时间复杂度为

\[\sum_{i}^{s}{k_i^2} \]

其中 \(Max\{k_i\} \le \sqrt{nm}\),那么有:

\[\sum_{i}^{s}{k_i^2} \le Max\{k_i\} \times \sum_{i}^{s}k_i \le \sqrt{nm} \times nm \]

所以总时间复杂度为 \(O(nm\sqrt{nm})\)

对于数量 \(k > \sqrt{nm}\) 的元素,最多会有:\(\frac{nm}{\sqrt{nm}} = \sqrt{nm}\) 个,每个的时间复杂度为 \(O(nm)\),那么总时间复杂度为 \(O(nm\sqrt{nm})\)

所以运用根号分治的方法,我们可以将数量 \(\le \sqrt{nm}\) 的用容斥原理,将 \(> \sqrt{nm}\) 用动态规划。

参考代码

#include <iostream>
#include <vector>
#include <map>
#include <set>
#include <algorithm>
#include <cmath>

using namespace std;

typedef long long ll;
typedef pair<ll, ll> PII; 

const int mod = 998244353;
const int N = 1e6 + 10;

ll fact[N], infact[N];
int n, m;

ll qmi(ll a, ll b)
{
	ll res = 1;
	while (b) {
		if (b & 1) res = res * a % mod;
		a = a * a % mod;
		b >>= 1;
	}
	return res;
}

void init()
{
	fact[0] = infact[0] = 1;
	for (int i = 1; i < N; i ++) fact[i] = fact[i - 1] * i % mod;
	infact[N - 1] = qmi(fact[N - 1], mod - 2);
	for (int i = N - 2; i; i --) infact[i] = infact[i + 1] * (i + 1) % mod;
}

ll C(int a, int b)
{
	if (b < 0 || a - b < 0) return 0;
	return fact[a] * infact[b] % mod * infact[a - b] % mod;
}

bool cmp(PII a, PII b)
{
	if (a.first == b.first) return a.second < b.second;
	return a.first < b.first;
}

ll func1(vector<PII>& a, int len)
{
	vector<ll> f(len, 0);
	ll res = 0;
	for (int i = 0; i < len; i ++) {
		f[i] = C(a[i].first + a[i].second - 2, a[i].first - 1);
		for (int j = 0; j < i; j ++) {
			if (a[j].first <= a[i].first && a[j].second <= a[i].second) {
				f[i] = (f[i] - f[j] * C(a[i].first - a[j].first + a[i].second - a[j].second, a[i].first - a[j].first) % mod + mod) % mod;
			}
		}
		ll num = f[i] * C(n - a[i].first + m - a[i].second, n - a[i].first) % mod;
		res = (res + num) % mod;
	}
	
	return res;
}

ll func2(vector<vector<ll>>& g, ll tol, int c)
{
	vector<vector<ll>> dp(n + 1, vector<ll>(m + 1, 0));
	dp[1][1] = 1;
	ll res = 0;
	if (g[1][1] == c) dp[1][1] = 0;
	for (int i = 1; i <= n; i ++)
		for (int j = 1; j <= m; j ++) {
			if (i == 1 && j == 1) continue;
			if (g[i][j] == c) continue;
			dp[i][j] = (dp[i - 1][j] + dp[i][j - 1]) % mod;
		}
	res = (tol - dp[n][m] + mod) % mod;
	return res;
}

void solve()
{
	cin >> n >> m;
	vector<vector<ll>> g(n + 1, vector<ll>(m + 1));
	set<int> v;
	for (int i = 1; i <= n; i ++)
		for (int j = 1; j <= m; j ++) {
			cin >> g[i][j];
			v.insert(g[i][j]);
		}
		
	map<int, vector<PII>> c;
	for (int i = 1; i <= n; i ++)
		for (int j = 1; j <= m; j ++)
			c[g[i][j]].emplace_back(i, j);
			
	for (auto it : v) {
		sort(c[it].begin(), c[it].end(), cmp);
	}
	
	int mid = sqrt(n * m);
	ll tol = C(n + m - 2, n - 1);
	ll ans = 0;
	for (int it : v) {
		auto a = c[it];
		int len = a.size();
		
		ll x;
		if (len <= mid) x = func1(a, len);
		else x = func2(g, tol, it);
		ans = (ans + x) % mod;
	}
	cout << ans << '\n';
}

int main()
{
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);
	
	init();
	
	int t;
	cin >> t;
	
	while (t --) solve();
	
	return 0;
}
posted @ 2025-05-25 18:14  Natural_TLP  阅读(242)  评论(0)    收藏  举报