【笔记】集合幂级数 2

【笔记】集合幂级数 2

参考资料

Optimal Algorithm on Polynomial Composite Set Power Series - Codeforces

feat(math/poly/sps.md): 增加集合幂级数 by hly1204 · Pull Request #5438 · OI-wiki/OI-wiki

【模板】子集卷积 / 集合幂级数 1 - caijianhong - 博客园

定义

想象有一个 \(n\) 元集合幂级数 \(A\),它的形状是

\[A(x)=\sum_{S\subseteq [n]}a_Sx^S \]

为了接下来的工作顺利进行,我们将其改写为 \(n\) 元的形式幂级数的形式,这样能体现每个元的独立性。

\[A(x_1, x_2, \cdots, x_n)=a_0+a_1x_1+a_2x_2+a_3x_1x_2+\cdots \]

在这个 \(n\) 元的形式幂级数上定义乘法为子集卷积,也就是说这个形式幂级数中只要有一个元的次数 \(\geq 2\),我们就不要了,和一般的集合意义相同。

实际上我们所做的事情就是将 \(A\) 这个集合幂级数扔进了

\[R[[x_1, x_2, \cdots, x_n]]/(x_1^2, x_2^2, \cdots, x_n^2) \]

这个形式幂级数环上,乘法也用这个环的定义。

子集卷积(形式幂级数乘法)

【模板】子集卷积 / 集合幂级数 1 - caijianhong - 博客园

集合幂级数运算

有两种方式进行集合幂级数之间运算:

基于 FWT 和子集卷积

【模板】子集卷积 / 集合幂级数 1 - caijianhong - 博客园

基于逐点牛顿迭代法

举个例子,假设要计算 \(F=\exp G\),那么对两边求对 \(x_n\) 的偏导数得 \(F'=FG'\)。这时候将后面的式子取 \([x_n^0]\) 的部分,因为 \([x_n^0]F'=[x_n^1]F\),所以

\[[x_n^1]F=[x_n^0]F\times [x_n^1]G \]

如果先递归求出 \([x_n^0]F\),那么就可以用一次子集卷积求出 \([x_n^1]F\),这样就解决了原问题。也就是说复杂度是 \(T(n)=T(n-1)+O(2^nn^2)=O(2^nn^2)\)

也有另外一种理解方式,我们提取其中的某个元,使 \(A=a+bx\) 其中 \(a, b\) 是两个集合幂级数,然后

\[\exp(a+bx)=\exp(a)\exp(bx)\equiv \exp(a)(1+bx)\equiv \exp(a)+\exp(a)bx\pmod {x^2} \]

先递归求出 \(\exp(a)\),就能求出 \(\exp(a+bx)\) 了,可以发现两种方法是一样的,但是第一种方法更加机械。

对于更加一般的,计算 \(F=f(G)\) 的任务,其中 \(f\) 是多项式或者乘法逆、对数、指数之类的性质好的函数。首先两边求偏导 \(F'=f'(G)G'\),然后也是提取 \([x_n^0]\) 的部分,将问题转化为求 \([x_n^0]f(G)\)\([x_n^0]f'(G)\)。如果没有意外的话,\(x_n\) 这一维就可以消去了,这样问题规模就缩小了。

由于对函数 \(f\) 的求导与 \(G\) 无关,将重复的问题合并,可以发现问题规模为 \(n-k\) 的子集卷积只用做 \(k+1\) 次,复杂度为 \(\sum_{k=0}^nk\cdot O(2^{n-k}(n-k)^2)\)。由于 \(\sum_{k\geq 0}k2^{-k}\) 收敛(为 \(2\)),所以复杂度还是 \(O(2^nn^2)\),但是显而易见的常数大。

比较

可以发现逐点牛顿迭代法在推导上有好处,它的时间复杂度也是固定的 \(O(2^nn^2)\),但只要复合 \(f\) 这一操作是可以 \(O(n^2)\) 的,这个算法就立即败下阵来。

例如 \(f\)\(\exp,\log\) 之类的,计算 \(f(g(x))\) 就可以直接求导并比较两侧系数,\(O(n^2)\) 解决。但如果 \(f\) 是多项式,那么 \(O(n^2)\) 的复合需要写 NTT,反而没有优势。

code

以下是集合幂级数的乘法逆和 \(\log\) 的代码实现。注意由于求多项式的乘法逆和对数(的截断)都可以 \(O(n^2)\),导致此算法相较于正常写法常数较大。会慢常数倍,大概 \(1.5\sim 3\) 倍。

另外也可以参考一下 Statistics #154 - LibreOJ 这里面最优解前几个的代码,都是逐点牛顿迭代法。

void fwt(mint vec[], int len, int op) {/*{{{*/
  if (op == +1) {
    for (int i = 1; i < len; i <<= 1) {
      for (int S = i; S < len; S = (S + 1) | i) vec[S] += vec[S ^ i];
    }
  } else {
    for (int i = 1; i < len; i <<= 1) {
      for (int S = i; S < len; S = (S + 1) | i) vec[S] -= vec[S ^ i];
    }
  }
}/*}}}*/
void subset_conv(mint lhs[], mint rhs[], mint res[], int len) {/*{{{*/
  int n = bitctz(len);
  vector<vector<mint>> f(n + 1, vector<mint>(1 << n));
  vector<vector<mint>> g(n + 1, vector<mint>(1 << n));
  for (int i = 0; i < 1 << n; i++) f[popc(i)][i] = lhs[i];
  for (int i = 0; i < 1 << n; i++) g[popc(i)][i] = rhs[i];
  for (int i = 0; i <= n; i++) fwt(f[i].data(), 1 << n, +1);
  for (int i = 0; i <= n; i++) fwt(g[i].data(), 1 << n, +1);
  vector<vector<mint>> h(n + 1, vector<mint>(1 << n));
  for (int i = 0; i <= n; i++) {
    for (int j = 0; j <= i; j++) {
      for (int k = 0; k < 1 << n; k++) h[i][k] += f[j][k] * g[i - j][k];
    }
  }
  for (int i = 0; i <= n; i++) fwt(h[i].data(), 1 << n, -1);
  for (int i = 0; i < 1 << n; i++) res[i] = h[popc(i)][i];
}/*}}}*/
void subset_conv3(mint lhs[], mint rhs[], mint mhs[], mint res[], int len) {/*{{{*/
  int n = bitctz(len);
  vector<vector<mint>> f(n + 1, vector<mint>(1 << n));
  vector<vector<mint>> g(n + 1, vector<mint>(1 << n));
  vector<vector<mint>> m(n + 1, vector<mint>(1 << n));
  for (int i = 0; i < 1 << n; i++) f[popc(i)][i] = lhs[i];
  for (int i = 0; i < 1 << n; i++) g[popc(i)][i] = rhs[i];
  for (int i = 0; i < 1 << n; i++) m[popc(i)][i] = mhs[i];
  for (int i = 0; i <= n; i++) fwt(f[i].data(), 1 << n, +1), fwt(g[i].data(), 1 << n, +1);
  for (int i = 0; i <= n; i++) fwt(m[i].data(), 1 << n, +1);
  vector<vector<mint>> h(n + 1, vector<mint>(1 << n));
  for (int i = 0; i <= n; i++) {
    for (int j = 0; j <= i; j++) {
      for (int k = 0; k < 1 << n; k++) h[i][k] += f[j][k] * g[i - j][k];
    }
  }
  swap(h, f);
  for (int i = 0; i <= n; i++) memset(h[i].data(), 0, sizeof(mint) << n);
  for (int i = 0; i <= n; i++) {
    for (int j = 0; j <= i; j++) {
      for (int k = 0; k < 1 << n; k++) h[i][k] += f[j][k] * m[i - j][k];
    }
  }
  for (int i = 0; i <= n; i++) fwt(h[i].data(), 1 << n, -1);
  for (int i = 0; i < 1 << n; i++) res[i] = h[popc(i)][i];
}/*}}}*/
void sps_inv(mint vec[], mint res[], int len) {
  int n = bitctz(len);
  res[0] = 1 / vec[0];
  for (int i = 0; i < n; i++) {
    //  subset_conv(vec + (1 << i), res, res + (1 << i), 1 << i);
    //  subset_conv(res + (1 << i), res, res + (1 << i), 1 << i);
    subset_conv3(vec + (1 << i), res, res, res + (1 << i), 1 << i);
    for (int j = 1 << i; j < 2 << i; j++) res[j] *= -1;
  }
}
void sps_log(mint vec[], mint res[], int len) {
  int n = bitctz(len);
  assert(vec[0] == 1);
  res[0] = 0;
  vector<mint> inv(len);
  sps_inv(vec, inv.data(), len);
  for (int i = 0; i < n; i++) {
    subset_conv(vec + (1 << i), inv.data(), res + (1 << i), 1 << i);
  }
}
posted @ 2025-03-12 16:39  caijianhong  阅读(178)  评论(0)    收藏  举报