分治 FFT(多项式求逆,分治fft)

题目

source

题解

方法一:多项式求逆

令\(g_0=0\),则原式可写成

\[f_i=\sum\limits_{j=0}^{i}{f_{i-j}g_j}\qquad(1\le i<n),\qquad f_0=1 \]

把\(f\)和\(g\)看作多项式,等式右边即为\(f\times g\)的第\(i\)项系数,这说明有\(f=f\times g\),只有\(i=0\)时例外:\((f\times g)_0 = f_0g_0 = 0\neq f_0=1\)。因此把它补上,就有

\[\begin{aligned} f&\equiv f\times g + f_0 &\mod x^n \\ f(1-g)&\equiv f_0 &\mod x^n \\ f&\equiv f_0(1-g)^{-1} &\mod x^n \end{aligned} \]

用多项式求逆求出\((1-g)^{-1}\)即可(其常数项为\(1-g_0=1\),逆一定存在)。时间复杂度\(O(n\log n)\)。

#include <bits/stdc++.h>

// Every later `endl` token is replaced by '\n', so line breaks do not flush the stream.
#define endl '\n'
// Fast iostreams: stop syncing with C stdio and untie cin from cout.
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
// Shorthand for make_pair (not used in the code shown).
#define mp make_pair
// Fixed-point output manipulator with N decimal places (not used in the code shown).
#define seteps(N) fixed << setprecision(N) 
// 64-bit integer type used for all modular arithmetic below.
typedef long long ll;

using namespace std;
/*-----------------------------------------------------------------*/

// Greatest common divisor by Euclid's algorithm, written iteratively:
// repeatedly replace (a, b) with (b, a % b) until b reaches zero.
ll gcd(ll a, ll b) {
    while (b != 0) {
        const ll rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// "Infinity" sentinel (not used in the code shown).
#define INF 0x3f3f3f3f

// Capacity of every global work array. solve() transforms arrays of size
// 2 * len, where len is n rounded up to a power of two, so N must exceed that.
const int N = 1e6 + 10;
// NTT modulus: 998244353 = 119 * 2^23 + 1, with primitive root 3 (used by ntt()).
const int M = 998244353;
// Bit-reversal permutation table, rebuilt by change() on every transform.
int rev[N];
// Modular exponentiation a^b mod m by binary exponentiation.
// Walks the bits of b from least significant upward; `base` holds a^(2^k) mod m
// for the current bit k. Returns 1 when b == 0.
inline ll qpow(ll a, ll b, ll m) {
    ll acc = 1;
    for (ll base = a; b != 0; b >>= 1) {
        if (b & 1) acc = acc * base % m;
        base = base * base % m;
    }
    return acc;
}

// Bit-reversal reordering of y[0..len-1], the first step of an iterative radix-2 transform.
// rev[i] is built from rev[i >> 1]: drop the low bit, then set the top bit
// (len >> 1) when i is odd. Each pair is swapped only once (when i < rev[i]).
void change(ll y[], int len) {
    const int top_bit = len >> 1;
    for (int i = 0; i < len; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? top_bit : 0);
    for (int i = 0; i < len; ++i)
        if (i < rev[i]) swap(y[i], y[rev[i]]);
}

// In-place number-theoretic transform of y[0..len-1] modulo M.
// len is assumed to be a power of two; on == -1 selects the inverse transform,
// which also multiplies every entry by len^{-1}. Inputs are assumed reduced to [0, M).
void ntt(ll y[], int len, int on) {
    change(y, len);
    for (int span = 2; span <= len; span <<= 1) {
        const int half = span / 2;
        // Primitive span-th root of unity from the primitive root 3.
        ll wn = qpow(3, (M - 1) / span, M);
        if (on == -1) wn = qpow(wn, M - 2, M);  // inverse transform uses the inverse root
        for (int start = 0; start < len; start += span) {
            ll w = 1;
            for (int k = start; k < start + half; k++) {
                const ll even = y[k];
                const ll odd = w * y[k + half] % M;
                y[k] = (even + odd) % M;
                y[k + half] = (even - odd + M) % M;
                w = w * wn % M;
            }
        }
    }
    if (on != -1) return;
    const ll inv_len = qpow(len, M - 2, M);
    for (int i = 0; i < len; i++) y[i] = y[i] * inv_len % M;
}

// Smallest power of two that is >= x (returns 1 for any x <= 1).
int get(int x) {
    for (int p = 1;; p <<= 1) {
        if (p >= x) return p;
    }
}

ll f[N], rf[N];

void solve(int len, ll x[], ll y[]) {
    if(len == 1) {
        y[0] = x[0];
        return ;
    }
    solve(len >> 1, x, y);
    for(int i = 0 ;i < (len << 1); i++) {
        f[i] = rf[i] = 0;
    }
    for(int i = 0; i < len / 2; i++) {
        rf[i] = y[i];
    }
    for(int i = 0; i < len; i++) {
        f[i] = x[i];
    }
    ntt(f, len << 1, 1);
    ntt(rf, len << 1, 1);
    for(int i = 0; i < (len << 1); i++) {
        rf[i] = rf[i] * (2 - rf[i] * f[i] % M + M) % M;
    }
    ntt(rf, len << 1, -1);
    for(int i = 0; i < len; i++) y[i] = rf[i];
}

ll a[N], b[N];

int main() {
    IOS;
    int n;
    cin >> n;
    for(int i = 1; i < n; i++) {
        cin >> a[i];
    }
    for(int i = 0; i < n; i++) {
        a[i] = ((i == 0) - a[i] % M + M) % M;
    }
    int len = get(n);
    solve(len, a, b);
    for(int i = 0; i < n; i++) {
        cout << b[i] << " \n"[i == n - 1];
    }

}

方法二:分治fft

\(f_i\)的计算依赖于前面所有\(f_j\ (j<i)\)的值,无法一次卷积直接求出,所以用CDQ分治计算,步骤大致为:

  • 计算左半部分;
  • 转移左半部分的贡献到右半部分;
  • 完成右半部分的计算。

注意转移贡献时,参与卷积的一个区间是左半部分\(f\)的区间[l,mid],另一个是\(g\)的区间[0,r-l](转移起点是0),这样才能覆盖到右半部分区间[mid+1,r]所需的全部\(g_{i-j}\)。

详见代码

#include <bits/stdc++.h>

// Every later `endl` token is replaced by '\n', so line breaks do not flush the stream.
#define endl '\n'
// Fast iostreams: stop syncing with C stdio and untie cin from cout.
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
// Shorthand for make_pair (not used in the code shown).
#define mp make_pair
// Fixed-point output manipulator with N decimal places (not used in the code shown).
#define seteps(N) fixed << setprecision(N) 
// 64-bit integer type used for all modular arithmetic below.
typedef long long ll;

using namespace std;
/*-----------------------------------------------------------------*/

// Greatest common divisor by Euclid's algorithm, written iteratively:
// repeatedly replace (a, b) with (b, a % b) until b reaches zero.
ll gcd(ll a, ll b) {
    while (b != 0) {
        const ll rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// "Infinity" sentinel (not used in the code shown).
#define INF 0x3f3f3f3f

// Capacity of the global arrays; solve() transforms arrays of up to get(n) entries.
const int N = 3e5 + 10;
// NTT modulus: 998244353 = 119 * 2^23 + 1, with primitive root 3 (used by fft()).
const int M = 998244353;
// Floating-point tolerance (not used in the code shown).
const double eps = 1e-5;

// Bit-reversal permutation table, rebuilt by change() on every transform.
int rev[N];

// Modular exponentiation a^b mod m by binary exponentiation.
// Walks the bits of b from least significant upward; `base` holds a^(2^k) mod m
// for the current bit k. Returns 1 when b == 0.
inline ll qpow(ll a, ll b, ll m) {
    ll acc = 1;
    for (ll base = a; b != 0; b >>= 1) {
        if (b & 1) acc = acc * base % m;
        base = base * base % m;
    }
    return acc;
}

// Bit-reversal reordering of y[0..len-1], the first step of an iterative radix-2 transform.
// rev[i] is built from rev[i >> 1]: drop the low bit, then set the top bit
// (len >> 1) when i is odd. Each pair is swapped only once (when i < rev[i]).
void change(ll y[], int len) {
    const int top_bit = len >> 1;
    for (int i = 0; i < len; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? top_bit : 0);
    for (int i = 0; i < len; ++i)
        if (i < rev[i]) swap(y[i], y[rev[i]]);
}

// In-place number-theoretic transform of y[0..len-1] modulo M (despite the
// name, this is an NTT; no floating point is involved).
// len is assumed to be a power of two; on == -1 selects the inverse transform,
// which also multiplies every entry by len^{-1}. Inputs are assumed reduced to [0, M).
void fft(ll y[], int len, int on) {
    change(y, len);
    for (int span = 2; span <= len; span <<= 1) {
        const int half = span / 2;
        // Primitive span-th root of unity from the primitive root 3.
        ll wn = qpow(3, (M - 1) / span, M);
        if (on == -1) wn = qpow(wn, M - 2, M);  // inverse transform uses the inverse root
        for (int start = 0; start < len; start += span) {
            ll w = 1;
            for (int k = start; k < start + half; k++) {
                const ll even = y[k];
                const ll odd = w * y[k + half] % M;
                y[k] = (even + odd) % M;
                y[k + half] = (even - odd + M) % M;
                w = w * wn % M;
            }
        }
    }
    if (on != -1) return;
    const ll inv_len = qpow(len, M - 2, M);
    for (int i = 0; i < len; i++) y[i] = y[i] * inv_len % M;
}

// Smallest power of two that is >= x (returns 1 for any x <= 1).
int get(int x) {
    for (int p = 1;; p <<= 1) {
        if (p >= x) return p;
    }
}

ll f[N], g[N];
ll tf[N], tg[N];

// CDQ divide and conquer for f_i = sum_{j=1}^{i} f_{i-j} g_j over the index
// range [l, r], whose size is assumed to be a power of two (fft needs it).
// On entry f[l..r] holds every contribution from indices < l; on return
// f[l..r] are final.
//
// Once the left half [l, mid] is final, its contribution to [mid+1, r] is one
// convolution of f[l..mid] with g[0..len-1]. A cyclic transform of length len
// suffices: products with index >= len wrap into [0, len/2 - 2], and only
// [len/2, len - 1] is read back.
void solve(int l, int r) {
    if (l == r) return;  // a single index receives no further contributions

    const int len = r - l + 1;
    const int mid = (l + r) / 2;

    solve(l, mid);

    for (int k = 0; k < len; k++) {
        tf[k] = (l + k <= mid) ? f[l + k] : 0;  // finished left half, zero-padded
        tg[k] = g[k];                           // g_0 .. g_{len-1}
    }

    fft(tf, len, 1);
    fft(tg, len, 1);
    for (int k = 0; k < len; k++) tf[k] = tf[k] * tg[k] % M;
    fft(tf, len, -1);

    // Entry k of the product is the contribution to f[l + k].
    for (int k = mid + 1 - l; k < len; k++) f[l + k] = (f[l + k] + tf[k]) % M;

    solve(mid + 1, r);
}

// Reads n and g_1..g_{n-1}, runs the CDQ recurrence from f_0 = 1, and prints
// f_0..f_{n-1} space-separated on one line.
int main() {
    IOS;
    int n;
    cin >> n;
    // g[0] is never read from input, so it stays 0.
    for (int j = 1; j < n; ++j) cin >> g[j];
    f[0] = 1;  // seeds the recurrence
    // CDQ runs over a power-of-two range covering [0, n).
    const int span = get(n);
    solve(0, span - 1);
    for (int j = 0; j < n; ++j) cout << f[j] << (j + 1 == n ? '\n' : ' ');
}
posted @ 2021-09-29 00:00  limil  阅读(220)  评论(0编辑  收藏  举报