# 概述

$FWT$是用来处理集合卷积的问题。也就是求解$h(n)=\sum\limits_{i|j=n}f(i)g(j)$类型的问题。其中或运算可以改为$\oplus,\&$(即异或、与)。

# 相互转化

// Forward OR transform, performed in place on a[0 .. 2^n).
//
// For every bit position, each index without that bit adds its value into
// the index that has the bit set; afterwards a[S] holds the sum of the
// original a[T] over all subsets T of S.
//
// `n` (number of bits) is a variable declared outside this snippet.
// `xs` is accepted but not used by this version.
void fwt_or(int *a,int xs) {
    const int size = 1 << n;
    for (int bit = 1; bit < size; bit <<= 1) {
        for (int mask = 0; mask < size; ++mask) {
            if (mask & bit) continue;  // only indices with this bit clear contribute
            a[mask | bit] += a[mask];
        }
    }
}


# 板子

## 或运算

// In-place OR transform on a[0 .. 2^n); no modular reduction is applied.
//
// xs =  1 : forward transform (each index accumulates the values of its subsets).
// xs = -1 : inverse transform (the same sweep with the contributions subtracted).
//
// `n` (number of bits) is a variable declared outside this snippet.
void fwt_or(int *a,int xs) {
    for (int bit = 1, full = 1 << n; bit < full; bit <<= 1)
        for (int mask = 0; mask < full; ++mask)
            if ((mask & bit) == 0)
                a[mask | bit] += xs * a[mask];
}


## and运算

// In-place AND transform on a[0 .. 2^n); no modular reduction is applied.
//
// xs =  1 : forward transform (each index accumulates the values of its supersets).
// xs = -1 : inverse transform (the same sweep with the contributions subtracted).
//
// `n` (number of bits) is a variable declared outside this snippet.
void fwt_and(int *a,int xs) {
    const int full = 1 << n;
    for (int shift = 0; shift < n; ++shift) {
        const int bit = 1 << shift;
        for (int mask = 0; mask < full; ++mask) {
            if (mask & bit) continue;  // the index without the bit is the receiver
            a[mask] += xs * a[mask | bit];
        }
    }
}



## 异或运算

// In-place XOR (Walsh-Hadamard) transform on a[0 .. 2^n), reduced with `% mod`.
//
// Every layer pairs index k with k + half and replaces the pair (lo, hi) by
// (lo + hi, lo - hi). Because `%` keeps the sign of the dividend, entries may
// end up in (-mod, mod) rather than [0, mod).
//
// xs == -1 selects the inverse: after the butterflies every entry is
// multiplied by the modular inverse of 2^n, computed as qm(2^n, mod - 2).
//
// `n`, `mod` and `qm` are declared outside this snippet.
void fwt_xor(int *a,int xs) {
    const int size = 1 << n;
    for (int half = 1; half < size; half <<= 1) {
        for (int base = 0; base < size; base += half << 1) {
            for (int k = base; k < base + half; ++k) {
                const int lo = a[k];
                const int hi = a[k + half];
                a[k] = (lo + hi) % mod;
                a[k + half] = (lo - hi) % mod;
            }
        }
    }
    if (xs != -1) return;

    const int inv = qm(size, mod - 2);
    for (int k = 0; k < size; ++k)
        a[k] = 1ll * a[k] * inv % mod;
}



# 模板题

luogu4717

/*
* @Author: wxyww
* @Date:   2020-04-26 08:03:27
*/
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<queue>
#include<vector>
#include<ctime>
using namespace std;
typedef long long ll;
const int N = 1 << 20,mod = 998244353;
// Reads the next decimal integer from stdin and returns it.
//
// Leading non-digit characters are skipped; if a '-' is seen while skipping,
// the result is negated. Parsing stops at the first non-digit after the
// number (that character is consumed). Returns 0 if end of input is reached
// before any digit is found.
//
// The return type is spelled `long long`, which is the same type as the
// file's `ll` typedef.
long long read() {
    long long value = 0, sign = 1;
    // Keep the character as int so EOF stays distinguishable from real bytes.
    int c = std::getchar();
    while (c < '0' || c > '9') {
        if (c == EOF) return 0;  // no number left: do not spin forever
        if (c == '-') sign = -1;
        c = std::getchar();
    }
    while (c >= '0' && c <= '9') {
        value = value * 10 + (c - '0');
        c = std::getchar();
    }
    return value * sign;
}
int A[N],B[N],n;

// In-place AND transform on a[0 .. 2^n), reduced with `% mod` after every update.
//
// xs =  1 : forward transform (each index accumulates the values of its supersets).
// xs = -1 : inverse transform; entries may then be negative, in (-mod, mod).
void fwt_and(int *a,int xs) {
    const int full = 1 << n;
    for (int bit = 1; bit < full; bit <<= 1) {
        for (int mask = 0; mask < full; ++mask) {
            if (mask & bit) continue;  // the index without the bit is the receiver
            a[mask] = (a[mask] + xs * a[mask | bit]) % mod;
        }
    }
}
// In-place OR transform on a[0 .. 2^n), reduced with `% mod` after every update.
//
// xs =  1 : forward transform (each index accumulates the values of its subsets).
// xs = -1 : inverse transform; entries may then be negative, in (-mod, mod).
void fwt_or(int *a,int xs) {
    const int full = 1 << n;
    for (int shift = 0; shift < n; ++shift) {
        const int bit = 1 << shift;
        for (int mask = 0; mask < full; ++mask) {
            if ((mask & bit) != 0) continue;  // only indices with this bit clear contribute
            int &target = a[mask | bit];
            target = (target + xs * a[mask]) % mod;
        }
    }
}

// Binary exponentiation: returns x^y modulo `mod`.
// y is consumed bit by bit from the least significant end while x is
// repeatedly squared.
ll qm(ll x,ll y) {
    ll result = 1;
    while (y) {
        if (y & 1) {
            result = result * x % mod;
        }
        x = x * x % mod;
        y >>= 1;
    }
    return result;
}
// In-place XOR (Walsh-Hadamard) transform on a[0 .. 2^n), reduced with `% mod`.
//
// For every bit, the pair (a[mask], a[mask | bit]) becomes (sum, difference).
// Because `%` keeps the sign of the dividend, entries may be negative.
// xs == -1 selects the inverse: every entry is additionally multiplied by
// the modular inverse of 2^n, obtained from qm(2^n, mod - 2).
void fwt_xor(int *a,int xs) {
    const int total = 1 << n;
    for (int bit = 1; bit < total; bit <<= 1) {
        for (int mask = 0; mask < total; ++mask) {
            if (mask & bit) continue;  // handle each pair once, from its lower index
            int &even = a[mask];
            int &odd = a[mask | bit];
            const int sum = even + odd;
            const int diff = even - odd;
            even = sum % mod;
            odd = diff % mod;
        }
    }
    if (xs == -1) {
        const int inv = qm(total, mod - 2);
        for (int *p = a; p != a + total; ++p) {
            *p = 1ll * (*p) * inv % mod;
        }
    }
}

int tmp1[N],tmp2[N];

// Computes one convolution of A and B and prints it on its own line.
//
// `transform` is one of fwt_or / fwt_and / fwt_xor: it is applied forward
// (xs = 1) to working copies of A and B, the copies are multiplied
// pointwise, and the product is transformed back (xs = -1).
static void convolve_and_print(void (*transform)(int *, int)) {
    // Work on copies so A and B stay intact for the next convolution.
    memcpy(tmp1, A, sizeof(tmp1));
    memcpy(tmp2, B, sizeof(tmp2));
    transform(tmp1, 1);
    transform(tmp2, 1);
    for (int i = 0; i < (1 << n); ++i) tmp1[i] = 1ll * tmp1[i] * tmp2[i] % mod;
    transform(tmp1, -1);
    // The transforms may leave negative residues; normalize into [0, mod).
    for (int i = 0; i < (1 << n); ++i) printf("%d ", (tmp1[i] + mod) % mod);
    puts("");
}

// Reads n and the two sequences of length 2^n, then prints their OR, AND
// and XOR convolutions, one per line.
int main() {
    // The bit count comes first in the input; without it n stays 0 and only
    // a single element of each sequence would be processed.
    n = static_cast<int>(read());
    for (int i = 0; i < (1 << n); ++i) A[i] = read();
    for (int i = 0; i < (1 << n); ++i) B[i] = read();

    convolve_and_print(fwt_or);
    convolve_and_print(fwt_and);
    convolve_and_print(fwt_xor);
    return 0;
}

posted @ 2020-04-26 10:24  wxyww  阅读(...)  评论(...编辑  收藏