FWT 学习笔记
FWT 部分
一:初步认识
FWT,即快速沃尔什变换,是用于解决对下标进行位运算卷积问题的方法。
时间复杂度为 \(O(n \log n)\)。
我们约定:
-
用大写字母表示一个多项式,例如 \(A\);
-
用 \(FWT(A)\) 表示 \(A\) 经过快速沃尔什变换后得到的幂级数(可以简单理解为多项式),并记 \(FWT(A)[i]\) 为该序列的第 \(i\) 项。
-
我们将用以下符号表示位运算:\(\land\)(按位与),\(\lor\)(按位或)和 \(\otimes\)(按位异或)。
-
记 \(a\) 的第 \(i\) 个二进制位上的值为 \(a_i\)。
二:概述
我们暂时不知道 FWT 应该满足什么要求,要不看一看 FFT?它有以下特性:
-
把运算转化为点积;
-
是线性变换;
第一点就是 \(FWT(A) \dot FWT(B)=FWT(C)\);
第二点就是 \(FWT(A)[i]=\sum_{j=0}^n c(i,j)A_j\),其中 \(c(i,j)\) 是 \(A_j\) 对 \(FWT(A)[i]\) 的贡献系数,也称为变换系数。
把第二点带入到第一点里推一下式子,就可以得到结论 \(c(i,j)c(i,k)=c(i,j \oplus k)\)。
同时,位运算有一个特点:可以分位考虑。因此有 \(c(i,j)=\prod_k c(i_k,j_k)\)。
那么知道了 \(c\) 后怎么求 \(FWT(A)\) 呢?
法一:\(FWT(A)[i]=\sum \limits_{j=0}^n c(i,j) A_j\)。最基本的 \(O(n^2)\) 做法。
法二:利用位运算可以分位计算的特点。观察上述式子,可以转化为:
注意到两个和式的 \(j\) 仅有最高位不同。记 \(i'\) 为 \(i\) 去掉最高二进制位后得到的数,\(i_0\) 为 \(i\) 最高二进制位上的值,那么将其提出可以得到:
那么这样就使原变换分成了两个规模减半的子变换,对每一个 \(i\) 做一次,时间复杂度 \(O(n^2\log n)\),怎么反而变劣了?注意到上述过程的瓶颈在于重复计算了很多次形如 \(\sum_j c(i,j)A[j]\) 的部分,这其实可以替换为 \(FWT(A')[i']\) 的。因此有一个更优的做法。
记 \(A_0\) 为幂级数 \(A\) 中首位为 \(0\) 的部分,类似的定义有 \(A_1\)。那么有:
当 \(i_0=0\) 时:
当 \(i_0=1\) 时:
在 \(n=1\) 时:可以认为 \(FWT(A)[i]=A_0\)。
已知幂级数的加法和数乘都是 \(O(n)\),那么至此遍实现了 \(O(n\log n)\) 的做法。(实际上在变化时 \(n=2^m\),那么复杂度也可以写为 \(O(m2^m)\))
那么 IFWT 的做法也很明晰了,就是把 \(c\) 换成 \(c^{-1}\) 即可。
三:具体运算
在概述部分已经将 \(FWT\) 的过程明确了,此时只剩下构造 \(c\) 矩阵这件事没有完成。那么这部分主要讲解怎么构造 \(c\) 矩阵。
\(c\) 矩阵也被称为位矩阵,通常情况下是 \(2\times 2\) 的,但是在 \(k\) 进制的位运算时,就是 \(k \times k\) 的了。
推导过程主要应用 \(c(i,j)c(i,k)=c(i,j\oplus k)\) 和位矩阵存在逆的性质。
1:或卷积
容易发现:
因为一行或一列全是 \(0\) 或有存在两行或两列相等的矩阵不存在逆,所以 \(c\) 只能为 \(\begin{bmatrix}1&1\\1&0\end{bmatrix}\) 或 \(\begin{bmatrix}1&0\\1&1\end{bmatrix}\)。第二个矩阵的实际意义是子集求和,因为它等价于 \(c(i,j)=[i \land j=j]=[j\subseteq i]\)。通常使用第二个矩阵。
第二个矩阵的逆为 \(\begin{bmatrix}1&0\\-1&1\end{bmatrix}\)。
2:与卷积
容易发现:
\(c\) 为 \(\begin{bmatrix}0&1\\1&1\end{bmatrix}\) 或 \(\begin{bmatrix}1&1\\0&1\end{bmatrix}\),还是采用第二种。
其逆为 \(\begin{bmatrix}1&-1\\0&1\end{bmatrix}\)。
3:异或卷积
容易发现:
\(c\) 为 \(\begin{bmatrix}1&1\\-1&1\end{bmatrix}\) 或 \(\begin{bmatrix}1&1\\1&-1\end{bmatrix}\),还是采用第二种。
其逆为 \(\begin{bmatrix}0.5&0.5\\0.5&-0.5\end{bmatrix}\)。
4:扩展
其实位运算的本质分别为:
- 或运算是对每一位分别取 \(\max\);
- 与运算是对每一位分别取 \(\min\);
- 异或运算是每一位相加后对 \(k\) 取模(在 \(k\) 进制下)
四:参考实现
模板题:【模板】快速莫比乌斯 / 沃尔什变换 (FMT / FWT)
其实对比一下 FFT,发现只是去掉了翻转二进制位这一步骤而已。
const int N=2e5+100,p=998244353,inv2=499122177;
int n,a[N],b[N];
ll f[N],g[N];
const ll Cor[2][2]={{1,0},{1,1}};
const ll inv_Cor[2][2]={{1,0},{p-1,1}};
const ll Cand[2][2]={{1,1},{0,1}};
const ll inv_Cand[2][2]={{1,p-1},{0,1}};
const ll Cxor[2][2]={{1,1},{1,p-1}};
const ll inv_Cxor[2][2]={{inv2,inv2},{inv2,p-inv2}};
void FWT(ll *f,const ll c[2][2]){
for(int len=1;len<n;len<<=1)
for(int j=0;j<n;j+=(len<<1))
for(int i=j;i<j+len;++i){
ll t=f[i];
f[i] =(t*c[0][0]+f[i+len]*c[0][1])%p;
f[i+len]=(t*c[1][0]+f[i+len]*c[1][1])%p;
}
}
void Mul(ll *f,ll *g,const ll c[2][2],const ll ic[2][2]){
FWT(f,c),FWT(g,c);
Down(i,n-1,0) f[i]=f[i]*g[i]%p;
FWT(f,ic);
}
void reset(){
Down(i,n-1,0) f[i]=a[i],g[i]=b[i];
}
void print(){
For(i,0,n-1) write(f[i]),putchar(' ');
putchar('\n');
}
int main()
{
n=read(),n=(1<<n);
For(i,0,n-1) a[i]=read();
For(i,0,n-1) b[i]=read();
reset();
Mul(f,g,Cor,inv_Cor);
print();
reset();
Mul(f,g,Cand,inv_Cand);
print();
reset();
Mul(f,g,Cxor,inv_Cxor);
print();
return 0;
}
子集卷积部分
一:概述
子集卷积形如:
看起来有点棘手,但是注意到如果 \(i \land j=0\),记 \(|i|\) 表示 \(i\) 在二进制下的 \(1\) 的个数,那么有 \(|i|+|j|=|k|\)。
因此可以记 \(A(k)[i]=[|i|=k]A[i]\),那么上述运算可以转化为:
其中 \(\oplus\) 运算是或卷积。最后还原到幂级数 \(C\) 就是 \(C[i]=C(|i|)[i]\)。
分析一下时间复杂度:或卷积是 \(O(m2^m)\) 的,外层枚举是 \(O(m)\) 的,因此时间复杂度为 \(O(2^mm^2)\)。
二:参考实现
模板题:【模板】子集卷积
const int N=(1<<20)+100,p=1e9+9;
int n,m,f[21][N],g[21][N],h[21][N],res[N],pc[N];
inline int bmod(int x){return x>=p ? x-p : x;}
void FWT(int *f,int flag){
for(int len=1;len<n;len<<=1)
for(int j=0;j<n;j+=(len<<1))
for(int i=j;i<j+len;++i)
f[i+len]=(flag ? bmod(f[i+len]+f[i]) : bmod(f[i+len]-f[i]+p));
}
void Mul(){
For(i,0,m) FWT(f[i],1),FWT(g[i],1);
For(i,0,m)
For(j,0,i)
Down(k,n-1,0)
h[i][k]=bmod(h[i][k]+1ll*f[j][k]*g[i-j][k]%p);
For(i,0,m) FWT(h[i],0);
}
int main()
{
m=read(),n=(1<<m);
For(i,0,n-1) pc[i]=__builtin_popcount(i);
For(i,0,n-1) f[pc[i]][i]=read();
For(i,0,n-1) g[pc[i]][i]=read();
Mul();
For(i,0,n-1) printf("%d ",h[pc[i]][i]);
return 0;
}

浙公网安备 33010602011771号