【hdu6507 Kanade's convolution】FWT
逛了广二一圈,食堂的量真不咋样,晚上果然又是被安排成了zhong庆坐在独特的二楼orz。
本题感谢hdhd的帮我找出了ntt做法的错误性orz
hdu6057
题意:
给定两个序列 A[0...2^m-1], B[0...2^m-1]
求 C[0...2^m-1] ,满足:
C[k] = ∑[i&j==k] A[i^j] * B[i|j]
m <= 19
考虑到一个性质i&j = i|j - i^j,那么如果我们枚举x = i|j , y = i^j,
那么有$C[k] = \sum\limits_{x}\sum\limits_{y} [x \ and \ y = y] [ x - y = k ] B[x] * A[y] * 2^{bit[y]} $
$C[k] = \sum\limits_{x}\sum\limits_{y} [x \ and \ y = y] [ x \ xor \ y = k ] B[x] * A[y] * 2^{bit[y]} $
$C[k] = \sum\limits_{x \ xor \ y=k} [x \ and \ y = y] B[x] * A[y] * 2^{bit[y]} $
$C[k] = \sum\limits_{x \ xor \ y=k} [bit(x) - bit(y) = bit(k) ] B[x] * A[y] * 2^{bit[y]} $
之后我们对于 bit个数不同的分别fwt卷积,然后对于不同的bit进行再对应卷积就可以了(说不清楚orz)
code:
#include<stdio.h>
#include<iostream>
#include<cstdio>
#define lowbit(x) ((x)&(-x))
using namespace std;
const int mod = 998244353;
int inv2;
int add(int x,int y) { x+=y; return x>=mod?x-mod:x; }
int sub(int x,int y) { x-=y; return x<0?x+mod:x; }
int mul(int x,int y) { return 1ll*x*y%mod; }
int ksm(int a,int b) {
int ans = 1;
for(;b;b>>=1,a=mul(a,a))
if(b&1) ans = mul(ans,a);
return ans;
}
int gbit(int x) {
int sm = 0;
for(;x;x-=lowbit(x)) sm++;
return sm;
}
void fwtxor(int *a,int s,int dft) {
for(int st=1;st<s;st<<=1) {
for(int i=0;i<s;i+=(st<<1)) {
for(int j=i;j<i+st;j++) {
int x = a[j]; int y = a[j+st];
a[j] = add(x,y); a[j+st] = sub(x,y);
if(dft==-1) a[j]=mul(a[j],inv2),a[j+st]=mul(a[j+st],inv2);
}
}
}
}
int n,S;
int bc[1<<20];
int A[20][1<<20],B[20][1<<20],C[20][1<<20];
int main() {
scanf("%d",&n);
inv2 = ksm(2,mod-2);
S = (1<<n)-1;
for(int i=0;i<=S;i++) bc[i] = gbit(i);
for(int i=0;i<=S;i++) {
int x; scanf("%d",&x);
A[bc[i]][i] = mul(1<<bc[i],x);
}
for(int i=0;i<=S;i++) {
int x; scanf("%d",&x);
B[bc[i]][i] = x;
}
for(int i=0;i<=n;i++) {
fwtxor(A[i],(1<<(n)),1);
fwtxor(B[i],1<<(n),1);
}
for(int k=0;k<=n;k++) {
for(int x=0;x<=k;x++) {
for(int i=0;i<=S;i++) {
C[x][i] = add(C[x][i],mul(B[k][i],A[k-x][i]));
}
}
}
for(int i=0;i<=n;i++) fwtxor(C[i],1<<(n),-1);
int ans = 0;
for(int i=0;i<=S;i++) {
ans = add(ans,mul(C[bc[i]][i],ksm(1526,i)));
}
printf("%d\n",ans);
}