【笔记】集合幂级数 2
【笔记】集合幂级数 2
参考资料
Optimal Algorithm on Polynomial Composite Set Power Series - Codeforces
feat(math/poly/sps.md): 增加集合幂级数 by hly1204 · Pull Request #5438 · OI-wiki/OI-wiki
【模板】子集卷积 / 集合幂级数 1 - caijianhong - 博客园
定义
想象有一个 \(n\) 元集合幂级数 \(A\),它的形状是
为了接下来的工作顺利进行,我们将其改写为 \(n\) 元的形式幂级数的形式,这样能体现每个元的独立性。
在这个 \(n\) 元的形式幂级数上定义乘法为子集卷积,也就是说这个形式幂级数中只要有一个元的次数 \(\geq 2\),我们就不要了,和一般的集合意义相同。
实际上我们所做的事情就是将 \(A\) 这个集合幂级数扔进了
这个形式幂级数环上,乘法也用这个环的定义。
子集卷积(形式幂级数乘法)
见 【模板】子集卷积 / 集合幂级数 1 - caijianhong - 博客园
集合幂级数运算
有两种方式进行集合幂级数之间运算:
基于 FWT 和子集卷积
见 【模板】子集卷积 / 集合幂级数 1 - caijianhong - 博客园
基于逐点牛顿迭代法
举个例子,假设要计算 \(F=\exp G\),那么对两边求对 \(x_n\) 的偏导数得 \(F'=FG'\)。这时候将后面的式子取 \([x_n^0]\) 的部分,因为 \([x_n^0]F'=[x_n^1]F\),所以
如果先递归求出 \([x_n^0]F\),那么就可以用一次子集卷积求出 \([x_n^1]F\),这样就解决了原问题。也就是说复杂度是 \(T(n)=T(n-1)+O(2^nn^2)=O(2^nn^2)\)。
也有另外一种理解方式,我们提取其中的某个元,使 \(A=a+bx\) 其中 \(a, b\) 是两个集合幂级数,然后
先递归求出 \(\exp(a)\),就能求出 \(\exp(a+bx)\) 了,可以发现两种方法是一样的,但是第一种方法更加机械。
对于更加一般的,计算 \(F=f(G)\) 的任务,其中 \(f\) 是多项式或者乘法逆、对数、指数之类的性质好的函数。首先两边求偏导 \(F'=f'(G)G'\),然后也是提取 \([x_n^0]\) 的部分,将问题转化为求 \([x_n^0]f(G)\) 和 \([x_n^0]f'(G)\)。如果没有意外的话,\(x_n\) 这一维就可以消去了,这样问题规模就缩小了。
由于对函数 \(f\) 的求导与 \(G\) 无关,将重复的问题合并,可以发现问题规模为 \(n-k\) 的子集卷积只用做 \(k+1\) 次,复杂度为 \(\sum_{k=0}^nk\cdot O(2^{n-k}(n-k)^2)\)。由于 \(\sum_{k\geq 0}k2^{-k}\) 收敛(为 \(2\)),所以复杂度还是 \(O(2^nn^2)\),但是显而易见的常数大。
比较
可以发现逐点牛顿迭代法在推导上有好处,它的时间复杂度也是固定的 \(O(2^nn^2)\),但只要复合 \(f\) 这一操作是可以 \(O(n^2)\) 的,这个算法就立即败下阵来。
例如 \(f\) 为 \(\exp,\log\) 之类的,计算 \(f(g(x))\) 就可以直接求导并比较两侧系数,\(O(n^2)\) 解决。但如果 \(f\) 是多项式,那么 \(O(n^2)\) 的复合需要写 NTT,反而没有优势。
code
以下是集合幂级数的乘法逆和 \(\log\) 的代码实现。注意由于求多项式的乘法逆和对数(的截断)都可以 \(O(n^2)\),导致此算法相较于正常写法常数较大。会慢常数倍,大概 \(1.5\sim 3\) 倍。
另外也可以参考一下 Statistics #154 - LibreOJ 这里面最优解前几个的代码,都是逐点牛顿迭代法。
void fwt(mint vec[], int len, int op) {/*{{{*/
if (op == +1) {
for (int i = 1; i < len; i <<= 1) {
for (int S = i; S < len; S = (S + 1) | i) vec[S] += vec[S ^ i];
}
} else {
for (int i = 1; i < len; i <<= 1) {
for (int S = i; S < len; S = (S + 1) | i) vec[S] -= vec[S ^ i];
}
}
}/*}}}*/
void subset_conv(mint lhs[], mint rhs[], mint res[], int len) {/*{{{*/
int n = bitctz(len);
vector<vector<mint>> f(n + 1, vector<mint>(1 << n));
vector<vector<mint>> g(n + 1, vector<mint>(1 << n));
for (int i = 0; i < 1 << n; i++) f[popc(i)][i] = lhs[i];
for (int i = 0; i < 1 << n; i++) g[popc(i)][i] = rhs[i];
for (int i = 0; i <= n; i++) fwt(f[i].data(), 1 << n, +1);
for (int i = 0; i <= n; i++) fwt(g[i].data(), 1 << n, +1);
vector<vector<mint>> h(n + 1, vector<mint>(1 << n));
for (int i = 0; i <= n; i++) {
for (int j = 0; j <= i; j++) {
for (int k = 0; k < 1 << n; k++) h[i][k] += f[j][k] * g[i - j][k];
}
}
for (int i = 0; i <= n; i++) fwt(h[i].data(), 1 << n, -1);
for (int i = 0; i < 1 << n; i++) res[i] = h[popc(i)][i];
}/*}}}*/
void subset_conv3(mint lhs[], mint rhs[], mint mhs[], mint res[], int len) {/*{{{*/
int n = bitctz(len);
vector<vector<mint>> f(n + 1, vector<mint>(1 << n));
vector<vector<mint>> g(n + 1, vector<mint>(1 << n));
vector<vector<mint>> m(n + 1, vector<mint>(1 << n));
for (int i = 0; i < 1 << n; i++) f[popc(i)][i] = lhs[i];
for (int i = 0; i < 1 << n; i++) g[popc(i)][i] = rhs[i];
for (int i = 0; i < 1 << n; i++) m[popc(i)][i] = mhs[i];
for (int i = 0; i <= n; i++) fwt(f[i].data(), 1 << n, +1), fwt(g[i].data(), 1 << n, +1);
for (int i = 0; i <= n; i++) fwt(m[i].data(), 1 << n, +1);
vector<vector<mint>> h(n + 1, vector<mint>(1 << n));
for (int i = 0; i <= n; i++) {
for (int j = 0; j <= i; j++) {
for (int k = 0; k < 1 << n; k++) h[i][k] += f[j][k] * g[i - j][k];
}
}
swap(h, f);
for (int i = 0; i <= n; i++) memset(h[i].data(), 0, sizeof(mint) << n);
for (int i = 0; i <= n; i++) {
for (int j = 0; j <= i; j++) {
for (int k = 0; k < 1 << n; k++) h[i][k] += f[j][k] * m[i - j][k];
}
}
for (int i = 0; i <= n; i++) fwt(h[i].data(), 1 << n, -1);
for (int i = 0; i < 1 << n; i++) res[i] = h[popc(i)][i];
}/*}}}*/
void sps_inv(mint vec[], mint res[], int len) {
int n = bitctz(len);
res[0] = 1 / vec[0];
for (int i = 0; i < n; i++) {
// subset_conv(vec + (1 << i), res, res + (1 << i), 1 << i);
// subset_conv(res + (1 << i), res, res + (1 << i), 1 << i);
subset_conv3(vec + (1 << i), res, res, res + (1 << i), 1 << i);
for (int j = 1 << i; j < 2 << i; j++) res[j] *= -1;
}
}
void sps_log(mint vec[], mint res[], int len) {
int n = bitctz(len);
assert(vec[0] == 1);
res[0] = 0;
vector<mint> inv(len);
sps_inv(vec, inv.data(), len);
for (int i = 0; i < n; i++) {
subset_conv(vec + (1 << i), inv.data(), res + (1 << i), 1 << i);
}
}
本文来自博客园,作者:caijianhong,转载请注明原文链接:https://www.cnblogs.com/caijianhong/p/18767899
浙公网安备 33010602011771号