取模(mod.cpp)

取模(mod.cpp) - 题解

考试的时候死活想不出来,考后也是经过 CQL 扑朔迷离的讲解才懂,下面介绍一下这个扑朔迷离的思路。

思路

观察 $\sum_{i = 1}^{n}\sum_{j - i + 1}^{n} \lfloor \frac{a_{i} + a_{j}}{2} \rfloor $,发现如果去掉 \(\frac{\cdots}{2}\) 和 $\lfloor \rfloor $ 会更简单,考虑先计算 \(\sum_{i = 1}^{n} \sum_{j - i + 1} ^ {n} (a_{i} + a_{j})\)

总的贡献

设总的贡献为 \(ans\),易得:

  • \(m ^ n\) 种序列;

  • 每个序列有 \(n\) 个数;

  • 对于一个序列,每个数在这个序列中贡献了 \(n - 1\) 次;

所以所有数的贡献总次数为 \(m ^ n n(n - 1)\) 次(注意:这里是贡献的次数)。

因为所有数出现的概率是相等的,所以每个数贡献了 \(m ^ {n - 1} n(n - 1)\)​ 次。

所以所有数贡献的总和 \(ans\) 为:

\[\begin{align*} ans &= m ^ {n - 1} n(n - 1) \times \frac{m(m + 1)}{2}\\ &= \frac{m ^ n n (n - 1)(m + 1)}{2} \end{align*} \]

要减去的贡献

观察 \(n\) 范围(\(n \le 10^6\)),可以 \(O(n)\) 枚举一个序列中奇数的个数。(下面设 \(m\) 个数中奇数的个数为 \(t_1\),偶数的个数为 \(t_0\)

设当前枚举到的奇数的个数为 \(x\),则当前偶数的个数为 \(n - x\),易得:

  • 可以在 \(t1\) 个奇数中随机选出 \(x\) 个,有 \(t1 ^ x\) 种;
  • 可以在 \(t0\) 个偶数数中随机选出 \(n - x\) 个,有 \(t0 ^ {n - x}\) 种;
  • \(x\) 个奇数随机放在 \(n\) 个空格里(或把 \(n - x\) 个奇数随机放在 \(n\) 个空格里),有 \(C_{n}^{x}\)\(C_{n}^{n - x}\)(两者等价)种方案;
  • 每一对奇数和偶数会减去 \(1\) 的贡献, 共有 \(x \times (n - x)\) 对。

所以:

\[\begin{align*} ans &= ans - \sum_{i = 1} ^ {n - 1}(t1 ^ x \times t0 ^ {n - x} \times C_{n}^{x} \times x(n - x))\\ &= ans - \sum_{i = 1} ^ {n - 1}(t1 ^ x t0 ^ {n - x} C_{n}^{x} \cdot x(n - x)) \end{align*} \]

枚举 \(x\) 即可。

代码:

#include <bits/stdc++.h>
#define ll long long
#define MOD 998244353
#define Maxn 1000005

using namespace std;

ll n, m, ans = 0, t0, t1;
ll f[Maxn], g[Maxn], fg[Maxn];

ll qpow(ll a, ll b) {
	ll res = 1;
	while (b) {
		if (b & 1) { res = res * a % MOD; }
		a = a * a % MOD;
		b >>= 1;
	} return res;
}

void init() {
	f[0] = g[0] = fg[0] = 1;
	for (ll i = 1; i <= Maxn - 2; i ++) {
		f[i] = f[i - 1] * i % MOD;
		g[i] = qpow(i, MOD - 2) % MOD;
		fg[i] = fg[i - 1] * g[i] % MOD;
	}
}

ll C(ll n, ll m) {
	if (m > n) { return 0; }
	if (m == 0) { return 1; }
	return ((f[n] * fg[n - m] % MOD) * fg[m] % MOD) % MOD;
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(0);
	
	cin >> n >> m, init();
	ans = qpow(m, n) * n % MOD * (n - 1) % MOD * (m + 1) % MOD * g[2] % MOD;
//	cout << ans << "\n";
	t0 = m / 2, t1 = t0 + (m % 2 == 1);
	
	for (int i = 1; i < n; i ++) {
		ans = (ans - (qpow(t0, i) * qpow(t1, n - i) % MOD * C(n, i) % MOD * i % MOD * (n - i) % MOD) % MOD + MOD) % MOD;
	} cout << ans * g[2] % MOD;
	return 0;
}
posted @ 2024-08-07 16:27  BLM-dolphin  阅读(30)  评论(0)    收藏  举报