取模(mod.cpp)
取模(mod.cpp) - 题解
考试的时候死活想不出来,考后也是经过 CQL 扑朔迷离的讲解才懂,下面介绍一下这个扑朔迷离的思路。
思路
观察 $\sum_{i = 1}^{n}\sum_{j - i + 1}^{n} \lfloor \frac{a_{i} + a_{j}}{2} \rfloor $,发现如果去掉 \(\frac{\cdots}{2}\) 和 $\lfloor \rfloor $ 会更简单,考虑先计算 \(\sum_{i = 1}^{n} \sum_{j - i + 1} ^ {n} (a_{i} + a_{j})\)。
总的贡献
设总的贡献为 \(ans\),易得:
-
有 \(m ^ n\) 种序列;
-
每个序列有 \(n\) 个数;
-
对于一个序列,每个数在这个序列中贡献了 \(n - 1\) 次;
所以所有数的贡献总次数为 \(m ^ n n(n - 1)\) 次(注意:这里是贡献的次数)。
因为所有数出现的概率是相等的,所以每个数贡献了 \(m ^ {n - 1} n(n - 1)\) 次。
所以所有数贡献的总和 \(ans\) 为:
\[\begin{align*}
ans &= m ^ {n - 1} n(n - 1) \times \frac{m(m + 1)}{2}\\
&= \frac{m ^ n n (n - 1)(m + 1)}{2}
\end{align*}
\]
要减去的贡献
观察 \(n\) 范围(\(n \le 10^6\)),可以 \(O(n)\) 枚举一个序列中奇数的个数。(下面设 \(m\) 个数中奇数的个数为 \(t_1\),偶数的个数为 \(t_0\))
设当前枚举到的奇数的个数为 \(x\),则当前偶数的个数为 \(n - x\),易得:
- 可以在 \(t1\) 个奇数中随机选出 \(x\) 个,有 \(t1 ^ x\) 种;
- 可以在 \(t0\) 个偶数数中随机选出 \(n - x\) 个,有 \(t0 ^ {n - x}\) 种;
- 把 \(x\) 个奇数随机放在 \(n\) 个空格里(或把 \(n - x\) 个奇数随机放在 \(n\) 个空格里),有 \(C_{n}^{x}\) 或 \(C_{n}^{n - x}\)(两者等价)种方案;
- 每一对奇数和偶数会减去 \(1\) 的贡献, 共有 \(x \times (n - x)\) 对。
所以:
\[\begin{align*}
ans &= ans - \sum_{i = 1} ^ {n - 1}(t1 ^ x \times t0 ^ {n - x} \times C_{n}^{x} \times x(n - x))\\
&= ans - \sum_{i = 1} ^ {n - 1}(t1 ^ x t0 ^ {n - x} C_{n}^{x} \cdot x(n - x))
\end{align*}
\]
枚举 \(x\) 即可。
代码:
#include <bits/stdc++.h>
#define ll long long
#define MOD 998244353
#define Maxn 1000005
using namespace std;
ll n, m, ans = 0, t0, t1;
ll f[Maxn], g[Maxn], fg[Maxn];
ll qpow(ll a, ll b) {
ll res = 1;
while (b) {
if (b & 1) { res = res * a % MOD; }
a = a * a % MOD;
b >>= 1;
} return res;
}
void init() {
f[0] = g[0] = fg[0] = 1;
for (ll i = 1; i <= Maxn - 2; i ++) {
f[i] = f[i - 1] * i % MOD;
g[i] = qpow(i, MOD - 2) % MOD;
fg[i] = fg[i - 1] * g[i] % MOD;
}
}
ll C(ll n, ll m) {
if (m > n) { return 0; }
if (m == 0) { return 1; }
return ((f[n] * fg[n - m] % MOD) * fg[m] % MOD) % MOD;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n >> m, init();
ans = qpow(m, n) * n % MOD * (n - 1) % MOD * (m + 1) % MOD * g[2] % MOD;
// cout << ans << "\n";
t0 = m / 2, t1 = t0 + (m % 2 == 1);
for (int i = 1; i < n; i ++) {
ans = (ans - (qpow(t0, i) * qpow(t1, n - i) % MOD * C(n, i) % MOD * i % MOD * (n - i) % MOD) % MOD + MOD) % MOD;
} cout << ans * g[2] % MOD;
return 0;
}

浙公网安备 33010602011771号