FWT小记

前言

这是博主最后一年寒假时候学的,仅为了自己复习写的,所以不够详细。等有时间了大概会来补充完善一下。
这篇博客或许不错

解决问题

\(c_i=\sum\limits_{j\circ k=i}a_j\times b_k\)

核心思想

对于 \(a,b\) 找到一个可逆变换,使得可以将变换后的 \(a,b\) 直接点乘得到变换后的 \(c\) ,然后逆运算回来。

或卷积

\(f_i=\sum\limits_{j \ |\ i=i} a_j\)\(g_i=\sum\limits_{j \ |\ i=i} b_j\)\(h_i=f_i \times g_i\)

\(h_i=\sum\limits_{j\ |\ i = i} \ \sum\limits_{k\ |\ i=i} a_j\times b_k=\sum\limits_{(j\ |\ k)\ | \ i=i}a_j\times b_k\)

所以 \(h\) 就是 \(c\) 的变换后的数组。此时对 \(h\) 做一个高维差分即可得到 \(c\)

与卷积

\(f_i=\sum\limits_{j \ \&\ i=i} a_j\)\(g_i=\sum\limits_{j \ \&\ i=i} b_j\)\(h_i=f_i \times g_i\)

\(h\) 做一个高维后缀差分。

异或卷积

定义 \(F(x)=popcount(x) \mod 2\)

\(f_i=\sum\limits_{F(i\ \otimes j)=0}a_i-\sum\limits_{F(i\ \otimes j)=1}a_i\)\(g_i=\sum\limits_{F(i\ \otimes j)=0}a_i-\sum\limits_{F(i\ \otimes j)=1}b_i\)

可以暴力分类讨论得出 \(h_i=f_i \times g_i\)

实现方式

因为这类变换 位之间都是独立的,我们考虑类似于 \(DP\) 那样从低到高一位一位地去实现。

void fwtor(ll *f,int op){
	for(ri len = 2,h = 1;len <= lim;len <<= 1,h <<= 1)
		for(ri i = 0;i < lim;i += len)
			for(ri j = i;j < i + h;++j)
				f[j + h] += f[j] * op,f[j + h] %= mod;
}
void fwtand(ll *f,int op){
	for(ri len = 2,h = 1;len <= lim;len <<= 1,h <<= 1)
		for(ri i = 0;i < lim;i += len)
			for(ri j = i;j < i + h;++j)
				f[j] += f[j + h] * op,f[j] %= mod;
}

void fwtxor(ll *f,int op){
	for(ri len = 2,h = 1;len <= lim;len <<= 1,h <<= 1)
		for(ri i = 0;i < lim;i += len)
			for(ri j = i;j < i + h;++j){
				f[j] = (f[j] + f[j+h]) % mod;
				f[j+h] = ((f[j] - f[j+h] - f[j+h]) % mod + mod) % mod;
				f[j] = f[j] * op % mod;
				f[j+h] = f[j+h] * op % mod;
			}
}

子集卷积

\(c_i=\sum\limits_{j\ | \ k=i\ ,\ j \ \&\ k = i}a_j\times b_k\)

将条件转换,\(\sum\limits_{j\ | \ k=i\ ,\ F(j) + F(k) = F(j|k)}\)\(F\) 定义同上

\(F(i)\)\(a\) 分类 ,\(a'_{F(i),i} = a_i\) ,把 \(a',b'\) 求出 \(\text{fwt}\) ,然后手动卷那个 \(F(i)\) ,最后把 \(c'\)\(\text{fwt}\) 搞回 \(c'\)\(c_i=c'_{F(i),i}\)

//from 2022.2.3 11:40
#include<bits/stdc++.h>
#define ri register int
#define ll long long
using namespace std;
const int maxn = (1<<20) + 5,mod = 1e9 + 9;
inline int rd(){
	int res = 0,f = 0; char ch = getchar();
	for(;!isdigit(ch);ch = getchar()) if(ch == '-') f = 1;
	for(;isdigit(ch);ch = getchar()) res = (res<<3) + (res<<1) + ch - 48;
	return f ? -res : res;
}
int n;
inline void fwt(ll *f,int op){//or 卷积
	for(ri mid = 1;mid < (1<<n);mid <<= 1)
		for(ri l = 0,len = (mid<<1);l < (1<<n);l += len)
			for(ri i = 0;i < mid;++i)
				if(op == 1) f[l + mid + i] = (f[l + mid + i] + f[l + i]) % mod;
				else f[l + mid + i] = (f[l + mid + i] - f[l + i] + mod) % mod;
}
int pop[maxn];
ll f[21][maxn],g[21][maxn],h[21][maxn];
int main(){
	n = rd();
	for(ri i = 1;i < (1<<n);++i) pop[i] = pop[i>>1] + (i&1);
	for(ri i = 0;i < (1<<n);++i) f[pop[i]][i] = rd();
	for(ri i = 0;i < (1<<n);++i) g[pop[i]][i] = rd();
	for(ri i = 0;i <= n;++i) fwt(f[i],1),fwt(g[i],1);
	for(ri i = 0;i <= n;++i)
		for(ri j = 0;j <= i;++j)
			for(ri k = 0;k < (1<<n);++k)
				h[i][k] = (h[i][k] + f[j][k] * g[i-j][k] % mod) % mod;
	for(ri i = 0;i <= n;++i) fwt(h[i],-1);
	for(ri i = 0;i < (1<<n);++i) printf("%lld ",h[pop[i]][i]);
	puts("");
	return 0;
}
posted @ 2022-02-03 22:26  Lumos壹玖贰壹  阅读(38)  评论(0编辑  收藏  举报