Loading

[ECNU] 3314. 多项式展开 (1)

https://acm.ecnu.edu.cn/problem/3314/

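（题面为图片）题意：给定 \(n\)、多项式系数 \(a_0,a_1,\dots,a_n\) 与整数 \(A\)，求 \(\sum_{i=0}^{n}a_i(x+A)^i\) 展开后各项 \(x^i\) 的系数 \(b_0,\dots,b_n\)，结果对 \(998244353\) 取模。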


由二项式定理展开得（记展开后 \(x^i\) 的系数为 \(b_i\)）：

\[\begin{aligned} &\sum_{i=0}^{n}a_i(x+A)^i \\ =&\sum_{i=0}^{n}a_i\sum_{j=0}^{i}{i \choose j} x^jA^{i-j} \\ =&\sum_{j=0}^{n}x^j\sum_{i=j}^{n}{i \choose j}a_iA^{i-j} \\ =&\sum_{j=0}^{n}x^j\sum_{i=j}^{n}\frac{i!}{j!(i-j)!}a_iA^{i-j} \\ &b_i=\sum_{j=i}^{n}\frac{j!}{i!(j-i)!}a_jA^{j-i} \\ &i! \cdot A^i \cdot b_i=\sum_{j=i}^{n}\frac{j!}{(j-i)!}a_jA^{j} \\ &p_i=i!\,A^i\,b_i,\quad q_j=j!\,a_j\,A^j,\quad w_k=\frac{1}{k!} \\ & p_i=\sum_{j=i}^{n}q_jw_{j-i} \\ & p_i=\sum_{j=0}^{n-i}q_{j+i}w_j \\ & p'_i=p_{n-i} \\ & p'_{n-i}=\sum_{j=0}^{n-i}q_{j+i}w_j \\ & p'_i=\sum_{j=0}^{i}q_{j+n-i}w_j \\ & q'_{i}=q_{n-i},\quad q_{j+n-i}=q'_{i-j} \\ & p'_i=\sum_{j=0}^{i}q'_{i-j}w_j \end{aligned} \]

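只需对 \(q'\) 与 \(w\) 做一次卷积（NTT）得到 \(p'\)，再由 \(p_i=p'_{n-i}\)、\(b_i=\dfrac{p_i}{i!\,A^i}\) 还原即可。这一步要求 \(A\not\equiv 0 \pmod{998244353}\)；当 \(A\equiv 0\) 时直接有 \(b_i=a_i\)。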

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 998244353;
const int N = 2e6;   // capacity of every global coefficient array

// Returns a^b modulo `mod` by binary exponentiation (b >= 0; b == 0 gives 1).
// The base is first reduced into [0, mod), so `a` may be negative or larger
// than mod. Without the reduction, a * a overflows long long once |a| exceeds
// roughly 3e9, and a negative base would produce a negative result.
ll pw(ll a, ll b) {
    a %= mod;
    if(a < 0) {
        a += mod;
    }
    ll res = 1;
    while(b) {
        if(b & 1) {
            res = res * a % mod;
        }
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
// Working arrays for main(); main only touches indices 0..n.
ll fac[N], invfac[N];               // fac[i] = i!, invfac[i] = (i!)^{-1} mod p
// a: input coefficients; q[j] = j! * a[j] * A^j; q_[j] = q[n - j];
// w[k] = 1 / k!; b: output coefficients of sum_i a_i (x + A)^i
ll a[N], q[N], q_[N], w[N], b[N];
ll p[N];                            // p[i] = i! * A^i * b[i]
ll A, pwA[N], invpwA[N];            // shift A; pwA[i] = A^i; invpwA[i] = A^{-i}
int n;                              // degree of the input polynomial

// Number-theoretic transform modulo 998244353, used for polynomial
// multiplication. Usage: write the two factors into a[0..n] and b[0..m]
// (all other entries zero), set n and m, then call ntt(). The product's
// coefficients are then left in a[0..n+m].
namespace NTT {
    // 3 is a primitive root of 998244353. g[1] = (mod + 1) / 3 is its
    // modular inverse and drives the inverse transform.
    const int mod = 998244353, g[2] = { 3, (mod + 1) / 3 };
    ll a[N], b[N], f[N];    // f is scratch space for the butterfly passes
    int n, m;

    // Reverses the lowest log2(n) bits of x (n is a power of two).
    int rev(int x, int n) {
        int r = 0;
        for(int i = 0 ; (1 << i) < n ; ++ i) {
            r = (r << 1) | ((x >> i) & 1);
        }
        return r;
    }

    // Iterative radix-2 transform of a[0..n), written back into a.
    // ty == 0: forward transform.
    // ty == 1: inverse transform, including the final division by n.
    // n must be a power of two not exceeding N.
    void ntt(ll *a, int ty, int n) {
        for(int i = 0 ; i < n ; ++ i) {
            f[rev(i, n)] = a[i];
        }
        for(int i = 2 ; i <= n ; i <<= 1) {
            // Primitive i-th root of unity (its inverse when ty == 1).
            ll wn = pw(g[ty], (mod - 1) / i);
            for(int j = 0 ; j < n ; j += i) {
                ll w = 1;
                for(int k = j ; k < j + i / 2 ; ++ k) {
                    ll u = f[k], v = f[k + i / 2] * w % mod;
                    f[k] = (u + v) % mod;
                    // Adding mod keeps the difference non-negative for
                    // non-negative inputs.
                    f[k + i / 2] = (u - v + mod) % mod;
                    w = w * wn % mod;
                }
            }
        }
        for(int i = 0, inv = pw(n, mod - 2) ; i < n ; ++ i) {
            a[i] = f[i];
            if(ty) {
                a[i] = a[i] * inv % mod;
            }
        }
    }

    // Convolves a[0..n] with b[0..m], leaving the result in a[0..n+m].
    void ntt() {
        // The product has n + m + 1 coefficients, so any power of two greater
        // than n + m avoids cyclic wrap-around. Taking the smallest one keeps
        // the work minimal and keeps len <= N for n + m < 2^20.
        int len = 1;
        while(len <= n + m) len <<= 1;
        ntt(a, 0, len), ntt(b, 0, len);
        for(int i = 0 ; i < len ; ++ i) {
            a[i] = a[i] * b[i] % mod;
        }
        ntt(a, 1, len);
    }
}

// Reads n, the coefficients a_0..a_n and the shift A, then prints the
// coefficients b_0..b_n of sum_i a_i (x + A)^i modulo 998244353.
// Uses i! * A^i * b_i = sum_{j>=i} (j! a_j A^j) / (j-i)!, which becomes a
// single convolution after reversing q (see the derivation above).
int main() {
    scanf("%d", &n);
    for(int i = 0 ; i <= n ; ++ i) {
        scanf("%lld", &a[i]);
        // Canonical residues: keeps later products in range and makes the
        // A == 0 shortcut print the same representation as the NTT path.
        a[i] = (a[i] % mod + mod) % mod;
    }
    scanf("%lld", &A);
    A = (A % mod + mod) % mod;

    // (x + A)^i == x^i when A is 0 modulo mod. The test must use the reduced
    // value: the general path divides by A^i, which has no inverse whenever
    // A is a multiple of mod.
    if(A == 0) {
        for(int i = 0 ; i <= n ; ++ i) {
            printf("%lld ", a[i]);
        }
        return 0;
    }

    // pwA[i] = A^i and invpwA[i] = A^{-i}, using a single modular inversion.
    pwA[0] = 1;
    for(int i = 1 ; i <= n ; ++ i) {
        pwA[i] = pwA[i - 1] * A % mod;
    }
    invpwA[n] = pw(pwA[n], mod - 2);
    for(int i = n - 1 ; i >= 0 ; -- i) {
        invpwA[i] = invpwA[i + 1] * A % mod;
    }

    // fac[i] = i! and invfac[i] = (i!)^{-1}, again with one inversion.
    fac[0] = 1;
    for(int i = 1 ; i <= n ; ++ i) {
        fac[i] = fac[i - 1] * i % mod;
    }
    invfac[n] = pw(fac[n], mod - 2);
    for(int i = n - 1 ; i >= 0 ; -- i) {
        invfac[i] = invfac[i + 1] * (i + 1) % mod;
    }

    // w[k] = 1/k!, q[j] = j! a_j A^j, q_ = q reversed.
    for(int i = 0 ; i <= n ; ++ i) {
        w[i] = invfac[i];
        q[i] = fac[i] * a[i] % mod * pwA[i] % mod;
        q_[n - i] = q[i];
    }

    NTT :: n = NTT :: m = n;
    for(int i = 0 ; i <= n ; ++ i) {
        NTT :: a[i] = q_[i];
        NTT :: b[i] = w[i];
    }
    NTT :: ntt();

    // The convolution yields p'_i = p_{n-i}; undo the reversal, then divide
    // out i! * A^i to recover b_i.
    for(int i = 0 ; i <= n ; ++ i) {
        p[i] = NTT :: a[n - i];
    }
    for(int i = 0 ; i <= n ; ++ i) {
        b[i] = p[i] * invfac[i] % mod * invpwA[i] % mod;
    }

    for(int i = 0 ; i <= n ; ++ i) {
        // + mod normalizes any negative residue left by the transform.
        printf("%lld ", (b[i] + mod) % mod);
    }
}
posted @ 2021-10-08 09:21  nekko  阅读(225)  评论(0)    收藏  举报