【AGC035F】Two Histograms

Problem

Description

你有一个 \(N\) 行、\(M\) 列的、每个格子都填写着 0 的表格。你进行了下面的操作:

  • 对于每一行 \(i\) ,选定自然数 \(r_i\)\(0\leq r_i\leq M\)),将这一行最左边的 \(r_i\) 个格子中的数 \(+1\).
  • 对于每一列 \(i\) ,选定自然数 \(c_i\)\(0\leq c_i\leq N\)),将这一列最上边的 \(c_i\) 个格子中的数 \(+1\).

这样,根据你选定的 \(r_1,r_2,\ldots,r_N,c_1,c_2,\ldots,c_M\) ,你就得到了一个每个格子要么是 0,要么是 1,要么是 2 的一个最终的表格。问本质不同的最终表格有多少种。两个表格本质不同当且进当它们有一个对应格子中的数不同。

Range

\(1\leq N,M \leq 5\cdot 10^5\)

Algorithm

容斥原理

Mentality

我们应该直接考虑重复的情况是怎么样的。

对于一对行和列,我们先假设其他行列的操作已经完成了,只需要考虑当前行列有多少种操作令结果不同。

然后缜密思索,我们发现只会有两个操作的结果是相同的。

假设我们正在考虑行 \(i\) 与列 \(j\) ,那么不难发现,只有当 \(r_i=j,c_j=i-1\)\(r_i=j-1,c_i=i\) 这两种情况时,它们的结果会相同,对于其他任意情况而言,结果唯一。

则我们只需要枚举有几对行列选择了这两种会重复的状态的前一种,剩下的随便填,然后利用容斥原理计算答案即可。

对于枚举 \(k\) ,则有:

\[f(k)=C^N_k*C^M_k*k!*(M+1)^{N-k}*(N+1)^{M-k} \]

则:

\[ans=\sum_{k=0}^{min(N,M)}(-1)^kf(k) \]

Code

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <vector>
using namespace std;
long long read() {
  long long x = 0, w = 1;
  char ch = getchar();
  while (!isdigit(ch)) w = ch == '-' ? -1 : 1, ch = getchar();
  while (isdigit(ch)) {
    x = (x << 3) + (x << 1) + ch - '0';
    ch = getchar();
  }
  return x * w;
}
const int Max_n = 5e5 + 5, mod = 998244353;
int n, m, ans;
int fac[Max_n], ifac[Max_n];
int f[Max_n];
int ksm(int a, int b) {
  int res = 1;
  for (; b; b >>= 1, a = 1ll * a * a % mod)
    if (b & 1) res = 1ll * res * a % mod;
  return res;
}
int C(int n, int m) { return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod; }
int main() {
#ifndef ONLINE_JUDGE
  freopen("F.in", "r", stdin);
  freopen("F.out", "w", stdout);
#endif
  n = read(), m = read();
  if (n > m) swap(n, m);
  fac[0] = ifac[0] = 1;
  for (int i = 1; i <= m; i++) fac[i] = 1ll * fac[i - 1] * i % mod;
  ifac[m] = ksm(fac[m], mod - 2);
  for (int i = m - 1; i; i--) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % mod;
  for (int i = 0; i <= n; i++) {
    f[i] = 1ll * C(n, i) * C(m, i) % mod * fac[i] % mod;
    f[i] = 1ll * f[i] * ksm(m + 1, n - i) % mod * ksm(n + 1, m - i) % mod;
  }
  for (int i = 0; i <= n; i++)
    ans = ((ans + ksm(-1, i & 1) * f[i]) % mod + mod) % mod;
  cout << ans;
}
posted @ 2019-09-01 19:22  洛水·锦依卫  阅读(309)  评论(0编辑  收藏  举报