【数学】快速Walsh变换
快速Walsh变换
给两个长度为 \(n\) 的序列 \(a,b\) ,满足 \(n=2^m\) ,序列标号为 \(a_0,a_1,\cdots,a_{n-1}\) ,
求AND卷积:
序列 \(c\) ,满足 \(c_i=\sum\limits_{(j\&k)=i}a_j\cdot b_k\)
求OR卷积:
序列 \(c\) ,满足 \(c_i=\sum\limits_{(j|k)=i}a_j\cdot b_k\)
求XOR卷积:
序列 \(c\) ,满足 \(c_i=\sum\limits_{(j\oplus k)=i}a_j\cdot b_k\)
复杂度 \(O(n\log n)\)(变换共 \(\log n\) 层,每层 \(O(n)\))
AND、OR 变换只用到加减法,对任意模数都成立;XOR 的逆变换每一层都要除以 2,只要模数是奇数(例如奇质数),就可以改为乘以 2 的逆元 \((\mathrm{MOD}+1)/2\);模数为偶数时 2 没有逆元,不能这样实现。
namespace FWT {
    // Modulus for all arithmetic below; the helpers expect operands in [0, MOD).
    const int MOD = 998244353;
    // Inverse of 2 modulo MOD: 2 * ((MOD + 1) / 2) = MOD + 1 == 1 (mod MOD),
    // which holds for any odd MOD.
    const int INV2 = (MOD + 1) >> 1;
    // (x + y) mod MOD for x, y in [0, MOD).
    inline int add(const int &x, const int &y) {
        int r = x + y;
        if(r >= MOD)
            r -= MOD;
        return r;
    }
    // (x - y) mod MOD for x, y in [0, MOD).
    inline int sub(const int &x, const int &y) {
        int r = x - y;
        if(r < 0)
            r += MOD;
        return r;
    }
    // (x * y) mod MOD for x, y in [0, MOD). The product is formed in long long
    // because it can exceed the range of int.
    inline int mul(const int &x, const int &y) {
        long long r = static_cast<long long>(x) * y;
        if(r >= MOD)
            r %= MOD;
        return static_cast<int>(r);
    }
    /* In-place fast Walsh transform of a[0..n-1]; n must be a power of two.
       op =
        +1 AND   (superset-sum transform: a[i] <- sum of a[j] over j containing i)
        -1 IAND  (inverse of +1)
        +2 OR    (subset-sum transform: a[i] <- sum of a[j] over j contained in i)
        -2 IOR   (inverse of +2)
        +3 XOR   (Walsh-Hadamard transform)
        -3 IXOR  (inverse of +3; each level multiplies by INV2, i.e. divides by n overall)
       Any other op terminates the program with exit(-1) before a is modified.
     */
    void FWT(int *a, int n, int op) {
        // Validate before touching a: a bad op is rejected even when n == 1
        // (no butterflies run) and can never leave a half-transformed.
        if(op == 0 || op < -3 || op > 3)
            exit(-1);
        for(int l = 1; l < n; l <<= 1) {
            for(int i = 0; i < n; i += (l << 1)) {
                for(int j = 0; j < l; ++j) {
                    int x = a[i + j], y = a[i + j + l];
                    switch(op) {
                    case +1:
                        a[i + j] = add(x, y);
                        break;
                    case +2:
                        a[i + j + l] = add(x, y);
                        break;
                    case +3:
                        a[i + j] = add(x, y);
                        a[i + j + l] = sub(x, y);
                        break;
                    case -1:
                        a[i + j] = sub(x, y);
                        break;
                    case -2:
                        a[i + j + l] = sub(y, x);
                        break;
                    case -3:
                        a[i + j] = mul(add(x, y), INV2);
                        a[i + j + l] = mul(sub(x, y), INV2);
                        break;
                    }
                }
            }
        }
    }
    /* Convolution of A and B (both of length n, a positive power of two) modulo MOD.
       op =
        +1 AND   c_i = sum over (j & k) == i of A_j * B_k
        +2 OR    c_i = sum over (j | k) == i of A_j * B_k
        +3 XOR   c_i = sum over (j ^ k) == i of A_j * B_k
       On return A holds c; B is overwritten with its forward transform.
     */
    void Convolution(int *A, int *B, int n, int op) {
        assert(n > 0 && __builtin_popcount(n) == 1);
        FWT(A, n, op), FWT(B, n, op);
        for(int i = 0; i < n; ++i)
            A[i] = mul(A[i], B[i]);
        FWT(A, n, -op);
    }
}
子集卷积
求子集卷积:
序列 \(c\) ,满足 \(c_i=\sum\limits_{(j\& k)=0,(j| k)=i}a_j\cdot b_k\)
复杂度:\(O(n\log^2 n)\)
// Number of set bits in x; used as the "rank" of index x in the subset convolution.
inline int cnt1(const int &x) {
    const unsigned bits = static_cast<unsigned>(x);
    return __builtin_popcount(bits);
}
const int MOD = 1000000009;     // 1e9 + 9, modulus of the subset convolution
const int INV2 = (MOD + 1) / 2; // inverse of 2 modulo the odd MOD; not referenced by the code shown here
// Modular sum (x + y) mod MOD; assumes both operands already lie in [0, MOD).
inline int add(const int &x, const int &y) {
    const int sum = x + y;
    return sum >= MOD ? sum - MOD : sum;
}
// Modular difference (x - y) mod MOD; assumes both operands already lie in [0, MOD).
inline int sub(const int &x, const int &y) {
    const int diff = x - y;
    return diff < 0 ? diff + MOD : diff;
}
// Modular product (x * y) mod MOD for x, y in [0, MOD).
// The product is formed in long long (spelled out instead of relying on an
// `ll` typedef) because it can exceed the range of int.
inline int mul(const int &x, const int &y) {
    long long r = static_cast<long long>(x) * y;
    if(r >= MOD)
        r %= MOD;
    return static_cast<int>(r);
}
// Subset-sum (OR) transform in place: afterwards a[i] is the sum, mod MOD,
// of the original a[j] over every j whose bits are a subset of i's.
// n is expected to be a power of two.
void FWT(int *a, int n) {
    for(int half = 1; half < n; half <<= 1) {
        const int step = half << 1;
        for(int base = 0; base < n; base += step) {
            int *lo = a + base;  // entries with the current bit clear
            int *hi = lo + half; // matching entries with the current bit set
            for(int k = 0; k < half; ++k)
                hi[k] = add(lo[k], hi[k]);
        }
    }
}
// Inverse of FWT (Moebius transform) in place: undoes the subset sums by
// subtracting each bit-clear entry from its bit-set partner, level by level.
// n is expected to be a power of two.
void IFWT(int *a, int n) {
    for(int half = 1; half < n; half <<= 1) {
        const int step = half << 1;
        for(int base = 0; base < n; base += step) {
            int *lo = a + base;
            int *hi = lo + half;
            for(int k = 0; k < half; ++k)
                hi[k] = sub(hi[k], lo[k]);
        }
    }
}
// Largest log2 of the sequence length that the tables below can hold.
const int MAXLOGN = 20;
// ln = log2 of the current sequence length and n = 1 << ln; both are set by Solve().
int ln;
int n;
// Rank-split tables: row r keeps only the entries whose index has popcount r
// (r = 0..MAXLOGN). Each holds (MAXLOGN + 1) << MAXLOGN ints (about 84 MiB with 4-byte int).
int a[MAXLOGN + 1][1 << MAXLOGN];
int b[MAXLOGN + 1][1 << MAXLOGN];
int c[MAXLOGN + 1][1 << MAXLOGN];
void Solve() {
    ms(a), ms(b), ms(c);
    scanf("%d", &ln), n = 1 << ln;
    for(int i = 0; i < n; ++i)
        scanf("%d", &a[cnt1(i)][i]);
    for(int i = 0; i < n; ++i)
        scanf("%d", &b[cnt1(i)][i]);
    for(int x = 0; x <= ln; ++x) {
        FWT(a[x], n), FWT(b[x], n);
        for(int y = 0; y <= x; ++y)
            for(int i = 0; i < n; ++i)
                c[x][i] = add(c[x][i], mul(a[y][i], b[x - y][i]));
        IFWT(c[x], n);
    }
    for(int i = 0; i < n; ++i)
        printf("%d%c", c[cnt1(i)][i], " \n"[i == n - 1]);
}

 
                
            
         浙公网安备 33010602011771号
浙公网安备 33010602011771号