题解 lugu P5591 小猪佩奇学数学

传送门


【分析】

单位根反演 + CZT

\(\begin{aligned} \sum_{i=0}^n\binom n i p^i\lfloor{i\over k}\rfloor&=\sum_{i=0}^n\binom n i p^i\cdot {i-(i\bmod k)\over k} \\&={1\over k}\left(\ p\sum_{i=0}^n \binom n i {\text d\over \text dp}p^i - \sum_{r=0}^{k-1}r\sum_{i=0}^n \binom n i p^i [k\mid (i-r)]\ \right) \\&={1\over k}\left(\ p\cdot {\text d\over \text dp}(p+1)^n-\sum_{r=0}^{k-1}r\sum_{i=0}^n \binom n i p^i\cdot {1\over k}\sum_{t=0}^{k-1}\omega_k^{(i-r)t}\ \right)&\text{(二项式定理、单位根反演)} \\&={1\over k}\left(\ p\cdot n(p+1)^{n-1}-{1\over k}\sum_{t=0}^{k-1}\sum_{r=0}^{k-1}r\cdot \omega_k^{-rt}\sum_{i=0}^n\binom n i (p\omega_k^t)^i\ \right) \\&={1\over k}\left(\ np\cdot (p+1)^{n-1}-{1\over k}\sum_{t=0}^{k-1}(\sum_{r=0}^{k-1}r\cdot \omega_k^{-rt})\cdot (p\omega_k^t+1)^n\ \right)&\text{(二项式定理)} \end{aligned}\)

\(\displaystyle f_t=\sum_{r=0}^{k-1}r\cdot \omega_k^{-rt}, g_t=(p\omega_k^t+1)^n\)

\(\displaystyle \sum_{i=0}^n\binom n i p^i\lfloor{i\over k}\rfloor={1\over k}\left(\ np\cdot (p+1)^{n-1}-{1\over k}\sum_{t=0}^{k-1}f_t\cdot g_t\ \right)\)

对于 \(np\cdot (p+1)^{n-1}\) 直接快速幂即可计算,对于 \(\displaystyle \sum_{t=0}^{k-1} f_t\cdot g_t\) 可以直接 \(O(n)\) 统计

只需要求出 \(\{f\}, \{g\}\) 即可


对于 \(\{g\}\) ,我们直接递推单位根 \(\omega_k^i\) ,然后乘上 \(p\) 再加 \(1\) ,之后跑快速幂即可

对于 \(\{f\}\) ,我们假定初始时 \(\displaystyle F(x)=\sum_{i=0}^{k-1}ix^i\)

同理,直接进行 ICZT 即可求出 \({1\over k}F(\omega_k^{-0}), {1\over k}F(\omega_k^{-1}), \cdots, {1\over k}F(\omega_k^{-(k-1)})\)

统一乘上 \(k\) 即可求出 \(f_0, f_1, \cdots , f_{k-1}\)

由于需要跑 ICZT ,代码中的数组需要开原来的两倍


【代码】

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pii;
typedef double db;
#define fi first
#define se second

const int LimBit=21, M=1<<LimBit<<1, P=998244353;
inline ll kpow(ll a,ll x) { ll ans=1; for(;x;x>>=1,a=a*a%P) if(x&1) ans=ans*a%P; return ans; }
int a[M], b[M], c[M];
struct NTT{
    static const int G=3;
    int N, na, nb, invN;
    NTT() {}

    int w[2][M], rev[M];
    void work() {
        int d=__builtin_ctz(N);
        w[0][0]=w[1][0]=1;
        for(int i=1, x=kpow(G, (P-1)/N), y=kpow(x, P-2); i<N; ++i) {
            rev[i]=(rev[i>>1]>>1)|((i&1)<<d-1);
            w[0][i]=(ll)x*w[0][i-1]%P, w[1][i]=(ll)y*w[1][i-1]%P;
        }
        invN=kpow(N, P-2);
    }
    inline void FFT(int *a, int f) {
        for(int i=0;i<N;++i) if(i<rev[i]) swap(a[i], a[rev[i]]);
        for(int i=1;i<N;i<<=1)
            for(int j=0, t=N/(i<<1); j<N; j+=i<<1)
                for(int k=0, l=0, x, y; k<i; ++k, l+=t)
                    x=(ll)w[f][l]*a[j+k+i]%P, y=a[j+k], a[j+k+i]=(y-x+P)%P, a[j+k]=(y+x)%P;
        if(f) for(int i=0;i<N;++i) a[i]=(ll)a[i]*invN%P;
    }
    inline void doit(int *a, int *b, int na, int nb) {
        for(N=1;N<na+nb-1;N<<=1);
        for(int i=na;i<N;++i) a[i]=0;
        for(int i=0;i<nb;++i) c[i]=b[i];
        for(int i=nb;i<N;++i) c[i]=0;
        work(); FFT(a, 0); FFT(c, 0);
        for(int i=0;i<N;++i) a[i]=(ll)a[i]*c[i]%P;
        FFT(a, 1);
    }
}ntt;

struct CZT{
    int g, N, invN;
    int w[2][M+M];
    CZT():g(3), N(0) {}
    inline void work(int N_) {
        if(N==N_) return ;
        N=N_;
        w[0][0]=w[1][0]=1;
        for(int i=1, x=kpow(g, (P-1)/N), y=kpow(x, P-2); i<N; ++i) {
            w[0][i]=(ll)w[0][i-1]*x%P;
            w[1][i]=(ll)w[1][i-1]*y%P;
        }
        invN=kpow(N, P-2);
    }
    inline void czt(int *c, int f, int N) {
        work(N);
        for(int i=0, j; i<N; ++i) {
            j=i*(i-1ll)/2%N;
            a[i]=(ll)c[i]*w[f^1][j]%P;
            b[i]=w[f][j];
            c[i]=w[f^1][j];
        }
        for(int i=N; i<N+N-1; ++i)
            b[i]=w[f][i*(i-1ll)/2%N];
        reverse(a, a+N);
        ntt.doit(a, b, N, N+N-1);
        for(int i=0, j=N-1; i<N; ++i, ++j) c[i]=(ll)a[j]*c[i]%P;
        if(f) for(int i=0;i<N;++i) c[i]=(ll)c[i]*invN%P;
    }
}czt;

int nn, pp, kk, f[M], g[M];
inline void init() {
    cin>>nn>>pp>>kk;

    ntt.N=kk;
    ntt.work();
    for(int i=0;i<kk;++i)
        g[i]=kpow(((ll)pp*ntt.w[0][i]+1)%P, nn);
    
    for(int i=0;i<kk;++i) f[i]=i;
    czt.czt(f, 1, kk);
    for(int i=0;i<kk;++i) f[i]=(ll)f[i]*kk%P;
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    init();
    int invk=kpow(kk, P-2), res=0;
    for(int i=0;i<kk;++i) res=(res+(ll)f[i]*g[i])%P;
    res=(ll)nn*pp%P*kpow(pp+1, nn-1)%P-(ll)invk*res%P;
    res=(ll)(res+P)%P*invk%P;
    cout<<res;
    cout.flush();
    return 0;
}
posted @ 2021-09-03 22:54  JustinRochester  阅读(43)  评论(0编辑  收藏  举报