快速傅里叶变换 FFT / NTT

前言:拼尽全力也只是一知半解,唯一的好处是:可以一知半解地背板子。<。)#)))≦认为不会考,我也这样认为,但是能多学一点也是好的。

P3803 【模板】多项式乘法(FFT)

这是我的NTT模板题

#include <bits/stdc++.h>
using namespace std;
// 64-bit type used for intermediate products before reduction modulo `mod`.
typedef long long ll;
// N:   capacity of the global arrays.
// mod: modulus for all arithmetic.
// g:   base passed to fastpow() for the forward transform.
// gi:  base passed to fastpow() for the inverse transform
//      (3 * 332748118 == 998244354 == mod + 1, so gi is the inverse of g modulo mod).
const int N = 3000005, mod = 998244353, g = 3, gi = 332748118;
// n, m: values read first in main(); n + 1 and m + 1 coefficients are read afterwards.
// lim:  transform length, doubled in main() until it exceeds n + m.
// L:    number of doublings applied to lim (so lim == 1 << L).
// r:    bit-reversal permutation used by NTT().
int n, m, lim = 1, L = 0, r[N];
// Coefficient arrays of the two input polynomials; the product ends up in a.
int a[N], b[N];
int fastpow(int a, int b)
{
    int res = 1;
    while(b)
    {
        if(b & 1) res = (ll)res * a % mod;
        a = (ll)a * a % mod;
        b >>= 1;
    }
    return res;
}
void NTT(int *A, int type)
{
    for (int i = 0; i < lim; ++ i) if(i < r[i]) swap(A[i], A[r[i]]);
    for (int mid = 1; mid < lim; mid <<= 1)
    {
        int wm = fastpow(type == 1 ? g : gi, (mod - 1) / (mid << 1));
        for (int j = 0; j < lim; j += (mid << 1))
        {
            int w = 1;
            for (int k = 0; k < mid; ++ k, w = ((ll)w * wm) % mod)
            {
                int x = A[j + k], y = (ll)w * A[j + k + mid] % mod;
                A[j + k] = (x + y) % mod;
                A[j + k + mid] = ((ll)x - y + mod) % mod;
            }
        }
    }
}
// Reads n and m, then n + 1 coefficients into a[] and m + 1 coefficients into b[],
// multiplies the two polynomials with NTT modulo `mod`, and prints the
// n + m + 1 coefficients of the product separated by spaces.
int main()
{
    scanf("%d %d", &n, &m);
    for (int i = 0; i <= n; ++ i) scanf("%d", &a[i]), a[i] %= mod;
    for (int i = 0; i <= m; ++ i) scanf("%d", &b[i]), b[i] %= mod;

    // Smallest power of two strictly greater than n + m.
    while(lim <= n + m) lim <<= 1, L ++;

    // r[0] is already 0 (zero-initialised global). Starting at 1 keeps the
    // values identical while never evaluating a shift by L - 1 == -1 when lim == 1.
    for (int i = 1; i < lim; ++ i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));

    NTT(a, 1), NTT(b, 1);
    for (int i = 0; i < lim; ++ i) a[i] = ((ll)a[i] * b[i]) % mod;
    NTT(a, -1);

    // The inverse pass is unscaled, so multiply by the modular inverse of lim.
    int inv = fastpow(lim, mod - 2);
    for (int i = 0; i <= n + m; ++ i)
    {
        // The product is long long; narrow it explicitly so it matches "%d".
        printf("%d ", (int)((ll)a[i] * inv % mod));
    }
    return 0;
}

【模板】高精度乘法 | A*B Problem 升级版

注意:预处理的时候 i 的范围是闭区间 [0, lim](单位根表 pow3 / powinv3 必须算到下标 lim,因为 NTT 的最后一层会访问 pow3[lim])。

#include <bits/stdc++.h>
using namespace std;
// Capacity of the global arrays.
const int N = 5000005;
// Modulus for all arithmetic.
const int mod = 998244353;
// 64-bit type used for intermediate products before reduction modulo `mod`.
typedef long long ll;
// a, b:    digit arrays of the two inputs, least significant digit at index 0.
// c:       copy of the product on which main() propagates carries.
// tmp:     not referenced by the code shown here.
// inv3:    inverse of 3 modulo mod, computed in init().
// pow3:    pow3[i]    = 3^((mod - 1) / i)    for each power of two i <= lim.
// powinv3: powinv3[i] = inv3^((mod - 1) / i) for each power of two i <= lim.
// r:       bit-reversal permutation used by NTT().
int a[N], b[N], c[N], tmp[N], inv3, pow3[N], powinv3[N], r[N];
// lim: transform length (power of two); L: number of doublings applied to lim;
// n, m: index of the highest digit of each input (digit count minus one).
int lim = 1, L = 0, n, m;
// Input buffer; main() reads each number starting at s + 1.
char s[N];
int fastpow(int a, int b)
{
    int res = 1;
    while(b)
    {
        if(b & 1) res = (ll)res * a % mod;
        a = (ll)a * a % mod;
        b >>= 1;
    }
    return res;
}
// Prepares every table NTT() needs, based on the globals n and m:
//   - lim / L: smallest power of two strictly greater than n + m, and its exponent;
//   - inv3:    modular inverse of 3;
//   - pow3[i], powinv3[i] for every power of two i in [1, lim] (inclusive, because
//     the last butterfly layer of NTT() reads index lim);
//   - r[]:     the bit-reversal permutation.
void init()
{
    while(lim <= n + m) lim <<= 1, L ++;
    inv3 = fastpow(3, mod - 2);
    for (int i = 1; i <= lim; i <<= 1)
    {
        pow3[i] = fastpow(3, (mod - 1) / i);
        powinv3[i] = fastpow(inv3, (mod - 1) / i);
    }
    // Position of the highest index bit. Clamped to 0 so that lim == 1 (L == 0,
    // both inputs are single digits) never shifts by -1, which is undefined.
    const int top = L > 0 ? L - 1 : 0;
    for (int i = 0; i <= lim; ++ i)
    {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << top);
    }
    return ;
}
// In-place iterative number-theoretic transform of A[0 .. lim) modulo `mod`.
//   A    - array of at least lim elements, overwritten with the result.
//   type - 1 uses the pow3 table; any other value uses powinv3. When type == -1
//          the result is additionally multiplied by the modular inverse of lim.
// Requires init() to have filled lim, r[], pow3[] and powinv3[].
void NTT(int *A, int type)
{
    // Apply the bit-reversal permutation; each pair is swapped exactly once.
    for (int i = 0; i < lim; ++ i) if(i < r[i]) swap(A[i], A[r[i]]);
    for (int mid = 1; mid < lim; mid <<= 1)
    {
        // Per-layer multiplier for blocks of length 2 * mid.
        const int wn = (type == 1) ? pow3[mid << 1] : powinv3[mid << 1];
        for (int j = 0; j < lim; j += (mid << 1))
        {
            int w = 1;
            for (int k = 0; k < mid; ++ k, w = (ll)w * wn % mod)
            {
                int x = A[j + k], y = (ll)w * A[j + k + mid] % mod;
                A[j + k] = ((ll)x + y) % mod;
                A[j + k + mid] = ((ll)x - y + mod) % mod;
            }
        }
    }
    if(type == -1)
    {
        // Scale the array that was actually transformed (the parameter A),
        // not the global a, so the function is correct for any argument.
        const int num = fastpow(lim, mod - 2);
        for (int i = 0; i < lim; ++ i) A[i] = (ll)A[i] * num % mod;
    }
    return ;
}
// Reads two non-negative decimal integers as strings, multiplies them by
// convolving their digit arrays with NTT, propagates carries, and prints the
// product without leading zeros (a zero product is printed as "0").
int main()
{
    // Digits are stored least significant first: a[i] is the digit of weight 10^i.
    scanf("%s", s + 1);
    n = strlen(s + 1) - 1;
    for (int i = 0; i <= n; ++ i) a[i] = s[n - i + 1] - '0';

    scanf("%s", s + 1);
    m = strlen(s + 1) - 1;
    for (int i = 0; i <= m; ++ i) b[i] = s[m - i + 1] - '0';

    init();
    NTT(a, 1);
    NTT(b, 1);
    for (int i = 0; i < lim; ++ i) a[i] = (ll)a[i] * b[i] % mod;
    NTT(a, -1);

    // Carry propagation in base 10; the final carry may land in c[lim].
    for (int i = 0; i < lim; ++ i) c[i] = a[i];
    for (int i = 0; i < lim; ++ i)
    {
        if(c[i] >= 10)
        {
            c[i + 1] += c[i] / 10;
            c[i] %= 10;
        }
    }

    // Skip leading zeros but always keep index 0, so that a product of zero
    // prints "0" instead of walking below the start of the array.
    int pp = lim;
    while(pp > 0 && c[pp] == 0) pp --;
    for (int i = pp; i >= 0; -- i) printf("%d", c[i]);
    return 0;
}

1096G - Lucky Tickets

处理 r 翻转数组的时候忘记右移了,调了很久。转化很有技巧:求多项式的 n 次方时,可以先根据结果的最高次数确定需要多少个点值(即 lim),只做一次 NTT,然后对每一个点值做快速幂,最后再做一次逆变换还原系数就可以了。

#include <bits/stdc++.h>
using namespace std;
// Capacity of the global arrays.
const int N = 3000005;
// Modulus for all arithmetic.
const int mod = 998244353;
// 64-bit type used for intermediate products before reduction modulo `mod`.
typedef long long ll;
// Modular power: computes a^b modulo `mod` by repeated squaring.
int quick(int a, int b)
{
    int answer = 1;
    int square = a;
    while (b != 0)
    {
        const bool odd = (b & 1) != 0;
        if (odd) answer = (ll)answer * square % mod;
        square = (ll)square * square % mod;
        b >>= 1;
    }
    return answer;
}

// n, k: the two integers read first in main(); k further values follow.
// lim:  transform length (power of two); L: number of doublings applied to lim.
// a:    a[x] is set to 1 for every value x read; later holds the transformed data.
// mx:   largest value read, then multiplied by n / 2 in main().
int n, lim = 1, L = 0, a[N], k, mx = 0;
// r:    bit-reversal permutation used by NTT().
// inv3: inverse of 3 modulo mod (assigned in main(), not read by the code shown here).
// num:  inverse of lim modulo mod, applied by NTT() when type == -1.
int r[N], inv3, num = 0; 
void NTT(int *A, int type)
{
    for (int i = 0; i < lim; ++ i) if(i < r[i]) swap(A[i], A[r[i]]);
    for (int mid = 1; mid < lim; mid <<= 1)
    {
        int wn = quick(3, (mod - 1) / (mid << 1));
        if(type == -1) wn = quick(wn, mod - 2);
        for (int j = 0; j < lim; j += (mid << 1))
        {
            int w = 1;
            for (int z = 0; z < mid; ++ z, w = (ll)w * wn % mod)
            {
                int x = A[j + z], y = (ll)w * A[j + z + mid] % mod;
                A[j + z] = (x + y) % mod;
                A[j + z + mid] = ((x - y) % mod + mod) % mod; 
            }
        }
    }
    if(type == -1)
    {
        for (int i = 0; i < lim; ++ i) A[i] = (ll)A[i] * num % mod;
    }
    return ;
}
// Reads n and k, then k values x, and sets a[x] = 1 for each of them.
// Raises the polynomial a to the power n / 2 (one forward NTT, pointwise
// modular power, one inverse NTT) and prints the sum of the squares of the
// resulting coefficients modulo `mod`.
int main()
{
    scanf("%d %d", &n, &k);
    for (int i = 1; i <= k; ++ i)
    {
        int x;
        scanf("%d", &x);
        a[x] = 1;
        mx = max(mx, x);
    }
    // Highest possible exponent in a^(n/2).
    mx = mx * (n / 2);
    lim = 1, L = 0;
    while(lim <= mx) lim <<= 1, L ++;
    num = quick(lim, mod - 2);
    inv3 = quick(3, mod - 2);
    // Position of the highest index bit. Clamped to 0 so that lim == 1
    // (mx == 0, i.e. only the value 0 was read) never shifts by -1.
    const int top = L > 0 ? L - 1 : 0;
    // Careful: r[i >> 1] must itself be shifted right by one here.
    for (int i = 0; i <= lim; ++ i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << top);
    NTT(a, 1);
    // Pointwise power of the transformed values corresponds to the polynomial power.
    for (int i = 0; i < lim; ++ i) a[i] = quick(a[i], n / 2);
    NTT(a, -1);
    int ans = 0;
    for (int i = 0; i < lim; ++ i) ans = (ans + (ll)a[i] * a[i] % mod) % mod;
    printf("%d", ans);
    return 0;
}
posted @ 2025-02-13 10:28  Helioca  阅读(39)  评论(0)    收藏  举报
Document