[ARC124F] Chance Meeting 题解
容斥,容斥,容斥。
思路
考虑设 \(f_{i,j}\) 为第一次在 \(i,j\) 相遇的方案数。
发现答案为:
\[ans=\sum_{i=1}^n\sum_{j=1}^m f_{i,j}f_{n-i+1,m-j+1}
\]
考虑走到 \((i,j)\) 并相遇的方案数:
\[\binom{n+j+j-3}{i-1,j-1,n-i,j-1}
\]
但是这样有可能在前面就相遇了。
我们可以发现,相遇只会在一行发生,所以可以容斥。
有:
\[\begin{align}
f_{i,j}&=\binom{n+j+j-3}{i-1,j-1,n-i,j-1}-\sum_{k<j}f_{i,k}\binom{2\times j-2\times k}{j-k}\\
&=\frac{(n+j+j-3)!}{(i-1)!(j-1)!(n-i)!(j-1)!}-\sum_{k<j}f_{i,k}\frac{(j+j-k-k)!}{(j-k)!(j-k)!}
\end{align}
\]
套路地,令 \(f_{i,j}=\frac{g_{j}}{(i-1)!(n-i)!}\)。
有:
\[\begin{align}
g_{j}&=\frac{(n+j+j-3)!}{(j-1)!(j-1)!}-\sum_{k<j}g_{k}\frac{(j+j-k-k)!}{(j-k)!(j-k)!}
\end{align}
\]
那么答案为:
\[\begin{align}
ans&=\sum_{i=1}^n\sum_{j=1}^m f_{i,j}f_{n-i+1,m-j+1}\\
&=\sum_{i=1}^n \frac{1}{(i-1)!(n-i)!(n-i)!(i-1)!} \sum_{j=1}^m g_{j}g_{m-j+1}\\
\end{align}
\]
分治 ntt 处理即可。
时间复杂度:\(O(n\log^2 n)\)。
Code
#include <bits/stdc++.h>
#include "atcoder/convolution"
using namespace std;
const int mod = 998244353;
int n, m;
int f[600010];
int v[600010];
int g[200010];
int h[200010];
inline int power(int x, int y) {
int res = 1;
while (y) {
if (y & 1) res = 1ll * res * x % mod;
x = 1ll * x * x % mod, y >>= 1;
}
return res;
}
inline void init(int n) {
f[0] = 1;
for (int i = 1; i <= n; i++) f[i] = 1ll * i * f[i - 1] % mod;
v[n] = power(f[n], mod - 2);
for (int i = n; i >= 1; i--) v[i - 1] = 1ll * i * v[i] % mod;
}
inline void del(int&x, int y) {
x = x - y;
if (x < 0) x += mod;
}
inline void sol(int L, int R, int l, int r) {
vector<int> a, b;
int e = R - l;
for (int i = l; i <= r; i++) a.push_back(g[i]);
for (int i = 0; i <= e; i++) b.push_back(h[i]);
a = atcoder::convolution(a, b);
for (int i = L; i <= R; i++) {
del(g[i], a[e - R + i]);
}
}
inline void cdq(int l, int r) {
if (l < r) {
int mid = (l + r) >> 1;
cdq(l, mid);
sol(mid + 1, r, l, mid);
cdq(mid + 1, r);
}
}
int main() {
cin >> n >> m;
init(n + m + m);
for (int i = 1; i <= m; i++) {
g[i] = 1ll * f[n + i + i - 3] * v[i - 1] % mod * v[i - 1] % mod;
h[i] = 1ll * f[i + i] * v[i] % mod * v[i] % mod;
}
cdq(1, m);
int ans1 = 0;
int ans2 = 0;
for (int i = 1; i <= m; i++)
ans1 = (ans1 + 1ll * g[i] * g[m - i + 1]) % mod;
for (int i = 1; i <= n; i++)
ans2 = (ans2 + 1ll * v[i - 1] * v[n - i] % mod * v[i - 1] % mod * v[n - i]) % mod;
cout << 1ll * ans1 * ans2 % mod << "\n";
}

浙公网安备 33010602011771号