Luogu13275 [NOI2025] 集合【容斥原理, 集合幂级数】

退役好久没做题所以也没更新. 最近放假在家摆烂摸鱼, 刷 b 站看到某位主播 vp 今年 NOI, 发现还有这种魔怔数数题, 就来做着玩玩. 训 AtCoder 取模计数题的人会有好报的. 😊

给定正整数 \(n\) 和序列 \(a_0,\ldots,a_{2^n-1}\in\mathbb F_p\), 对 \(S\subseteq [0,2^n)\cap\mathbb Z\) 定义 \(\mathrm{And}(S)\)\(S\) 中所有数的按位与, 要求:

\[\sum_{S,T\subseteq[0,2^n)}\mathbf 1\{S\cap T=\varnothing\}\cdot\mathbf 1\{\mathrm{And}(S)=\mathrm{And}(T)\}\prod_{i\in S\cup T}a_i. \]

数据范围: \(n\le 20\)\(p=998\,244\,353\), 其中部分测试点满足 \(a_i\ne -1\).

考虑对 \(\mathbf 1\{\mathrm{And}(S)=\mathrm{And}(T)\}\) 容斥: 枚举 \(x,y\in[0,2^n)\), 设 \(S_x\) 表示包含 \(x\) 的数的集合, 条件 \(x\subseteq\mathrm{And}(S)\)\(y\subseteq\mathrm{And}(T)\) 等价于 \(S\subseteq S_x\)\(T\subseteq S_y\), 又因为

\[\mathbf 1\{\mathrm{And}(S)=\mathrm{And}(T)\}=\sum_{x\subseteq\mathrm{And}(S)}\sum_{y\subseteq\mathrm{And}(T)}(-1)^{\mathrm{pc}(x)+\mathrm{pc}(y)}\cdot 2^{\mathrm{pc}(x\&y)}, \]

代回原式即得容斥系数是 \((-1)^{\mathrm{pc}(x)+\mathrm{pc}(y)}\cdot 2^{\mathrm{pc}(x\&y)}\), 贡献是 \(\sum_{S\subseteq S_x}\sum_{T\subseteq S_y}\mathbf 1\{S\cap T=\varnothing\}\prod_{i\in S\cup T}a_i\).

这里 \(\mathrm{pc}(x)\) 表示 \(x\) 的二进制表示中 1 的数量. 容斥系数怎么算的? 将条件表示为每个二进制位的 "相等" 条件相乘, 权值 \(((1,0),(0,1))\) 作差分得到 \(((1,-1),(-1,2))\), 按分配律全部展开就好啦.

贡献式里每个元素的方案独立: \(S_x\cap S_y=S_{x|y}\) 的元素可以放入 \(S\)\(T\) 或不放, 方案数为 \(1+2a_i\); 其他 \((S_x\cup S_y)\setminus S_{x|y}\) 的元素则是 \(1+a_i\).

又因为 \(\mathrm{pc}(x|y)=\mathrm{pc}(x)+\mathrm{pc}(y)-\mathrm{pc}(x\&y)\), 我们将其化为 OR 卷积的形式: 设 \(s_x:=\prod_{i\supseteq x}(1+a_i)\), \(t_x:=\prod_{i\supseteq x}\frac{1+2a_i}{(1+a_i)^2}\), 答案是 \((-2)^{\mathrm{pc}(x)}s_x\) 卷积自己, 逐项乘上 \(2^{-\mathrm{pc}(x)}t_x\) 再求和.

如果有 \(a_i=-1\) 咋办呢? 我们要把 \(s_x\)\(s_y\) 中属于 \(t_{x|y}\) 的 0 因子剔除. 考虑 FMT 的过程, 什么情况下 \(s_xs_y\) 的卷积项会贡献到 \(z\) 的位置? 只有 \(z\) 包含 \(x\)\(y\) 的情况. 此时 \(t_z\) 的分母必定整除 \(s_xs_y\), 即求和的每一项都被 \(t_z\) 的分母整除. 所以只需要维护 0 次数最低的一项, 就算有 \(a+(-a)\) 相消也不用关心更高次项的系数, 肯定是用不到的.

随便看了一下其他题解, 感觉这部分讲得比较模糊, 所以在洛谷也传了一份.

总结一下, 用后缀和/积计算 \(s_x\)\(t_x\) 的 0 次数和系数, 再跑一个 OR 卷积就完成了, 时间复杂度 \(O(n2^n)\).

#include <bits/stdc++.h>
#define fi first
#define se second
#define pc __builtin_popcount
typedef std::pair<int, int> PII; // coefficient and degree of 0
typedef long long LL;

const int mod = 998244353;
int qmo(int x) { return x + (x >> 31 & mod); }
int ksm(int a, int b) {
  int res = 1;
  for (; b; b >>= 1, a = (LL)a * a % mod)
    if (b & 1) res = (LL)res * a % mod;
  return res;
}

PII operator + (const PII &a, const PII &b) {
  if (a.se == b.se) return PII(qmo(a.fi + b.fi - mod), a.se);
  return a.se < b.se ? a : b;
}
PII operator += (PII &a, const PII &b) { a = a + b; return a; }

PII operator - (const PII &a) { return PII(qmo(-a.fi), a.se); }
PII operator - (const PII &a, const PII &b) { return a + (-b); }
PII operator -= (PII &a, const PII &b) { a = a - b; return a; }

PII operator * (const PII &a, int b) {
  if (b == 0) return PII(a.fi, a.se + 1);
  return PII((LL)a.fi * b % mod, a.se);
}
PII operator *= (PII &a, int b) { a = a * b; return a; }

PII operator * (const PII &a, const PII &b) {
  return PII((LL)a.fi * b.fi % mod, a.se + b.se);
}
PII operator *= (PII &a, const PII &b) { a = a * b; return a; }

void solve() {
  int n; std::cin >> n;
  std::vector<int> pwn2(n+1), pwi2(n+1); // power of -2 and 1/2
  pwn2[0] = pwi2[0] = 1;
  for (int i = 1; i <= n; ++i) {
    pwn2[i] = pwn2[i-1] * (mod - 2ll) % mod;
    pwi2[i] = (pwi2[i-1] + (pwi2[i-1] & 1) * mod) / 2;
  }

  int N = 1 << n; std::vector<PII> s(N), t(N);
  for (int i = 0, x; i < N; ++i) {
    std::cin >> x;
    if (x == mod - 1) {
      s[i] = PII(1, 1);
      t[i] = PII(mod - 1, -2);
    } else {
      s[i] = PII(1 + x, 0);
      t[i] = PII((1 + 2ll * x) * ksm(1 + x, mod - 3) % mod, 0);
    }
  }

  for (int i = 0; i < n; ++i)
    for (int j = 0; j < N; ++j)
      if (!(j & (1 << i))) {
        s[j] *= s[j ^ (1 << i)];
        t[j] *= t[j ^ (1 << i)];
      }

  for (int i = 0; i < N; ++i) {
    s[i] *= pwn2[pc(i)];
    t[i] *= pwi2[pc(i)];
  }

  for (int i = 0; i < n; ++i)
    for (int j = 0; j < N; ++j)
      if (j & (1 << i)) s[j] += s[j ^ (1 << i)];
  for (int i = 0; i < N; ++i) s[i] *= s[i];
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < N; ++j)
      if (j & (1 << i)) s[j] -= s[j ^ (1 << i)];

  int ans = 0;
  for (int i = 0; i < N; ++i) {
    PII cur = s[i] * t[i]; assert(cur.se >= 0);
    if (cur.se == 0) ans = qmo(ans + cur.fi - mod);
  }
  printf("%d\n", ans);
}

int main() {
  int _, t;
  std::ios::sync_with_stdio(false);
  std::cin >> _ >> t;
  while (t --) { solve(); }
}
posted @ 2025-07-23 09:58  mizu164  阅读(29)  评论(0)    收藏  举报