题解:AT_abc422_g Balls and Boxes

更差的阅读体验


【模板】生成函数


Problem 1

无标号数球盒问题,果断想到普通生成函数。

对于第一个盒子,如果盒子个数是 \(A\) 的倍数,就产生 \(1\) 的贡献。容易构造

\[f_A(x) = \sum _{k \isin \N} x^{kA} \]

同理有

\[f_B(x) = \sum _{k \isin \N} x^{kB} \]

\[f_C(x) = \sum _{k \isin \N} x^{kC} \]

我们要求的是划分 \(n\) 个球的方案数,也就是

\[[x^n]f_A(x)f_B(x)f_C(x) \]

NTT 即可。

Problem 2

有标号数球盒问题,果断想到指数生成函数。

同上,为了消除标号的贡献,我们要在每一项的系数上除以一个排列数。容易构造

\[f_A(x) = \sum _{k \isin \N} \frac{x^{kA}}{(kA)!} \]

\[f_B(x) = \sum _{k \isin \N} \frac{x^{kB}}{(kB)!} \]

\[f_C(x) = \sum _{k \isin \N} \frac{x^{kC}}{(kC)!} \]

记得把标号乘回来。答案是

\[[\frac{x^n}{n!}]f_A(x)f_B(x)f_C(x) \]

NTT 即可。


复杂度 \(O(n \log n)\),代码没人会想看的。

#include<bits/stdc++.h>
#define int long long
#define endl '\n'
#define N 2000006
using namespace std;
namespace POLY{ //by dyc2022
    const int MOD=998244353,G=3,invg=(MOD+1)/3,img=86583718;
    int r[N];
    inline int qpow(int x,int y)
    {
        if(y==0)return 1;
        if(y==1)return x%MOD;
        int ret=qpow(x,y>>1);
        return ret*ret%MOD*qpow(x,y&1)%MOD;
    }
    inline void NTT(int len,int *a,int opt)
    {
        for(int i=0;i<len;i++)if(i<r[i])
            swap(a[i],a[r[i]]);
        for(int i=1;i<len;i<<=1)
        {
            int tmp=i<<1,Wn=qpow(opt==1?G:invg,(MOD-1)/tmp);
            for(int j=0;j<len;j+=tmp)
            {
                int w=1,x,y;
                for(int k=0;k<i;k++,w=w*Wn%MOD)
                    x=a[j+k],y=w*a[i+j+k]%MOD,a[j+k]=(x+y)%MOD,a[i+j+k]=(x+MOD-y)%MOD;
            }
        }
        if(opt!=1)
        {
            int invn=qpow(len,MOD-2);
            for(int i=0;i<len;i++)a[i]=a[i]*invn%MOD;
        }
    }
    inline void times(int n,int m,int *a,int *b,int *ans)
    {
        int len=1,lg=0;
        while(len<=n+m)len<<=1,lg++;
        for(int i=0;i<len;i++)
            r[i]=(r[i>>1]>>1)|((i&1)<<lg-1);
        NTT(len,a,1),NTT(len,b,1);
        for(int i=0;i<=len;i++)a[i]=a[i]*b[i]%MOD;
        NTT(len,a,-1);
        for(int i=0;i<=n+m;i++)ans[i]=a[i];
    }
    int c[N];
    inline void getinv(int n,int *a,int *b)
    {
        if(n==1)return b[0]=qpow(a[0],MOD-2),(void)0;
        getinv(n+1>>1,a,b);
        int len=1,lg=0;
        while(len<(n<<1))len<<=1,lg++;
        for(int i=0;i<len;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(lg-1));
        for(int i=0;i<n;i++)c[i]=a[i];
        for(int i=n;i<len;i++)c[i]=0;
        NTT(len,c,1),NTT(len,b,1);
        for(int i=0;i<len;i++)
            b[i]=(2+MOD-c[i]*b[i]%MOD)%MOD*b[i]%MOD;
        NTT(len,b,-1);
        for(int i=n;i<len;i++)b[i]=0;
        memset(c,0,sizeof(c));
    }
    int f[N],g[N],fr[N],gr[N],invgr[N],dr[N],td[N],tmp[N];
    inline void divide(int n,int m,int *tf,int *tg,int *d)
    {
        for(int i=0;i<n+m<<1;i++)
            f[i]=g[i]=fr[i]=gr[i]=invgr[i]=dr[i]=tmp[i]=0;
        for(int i=0;i<n;i++)f[i]=tf[i];
        for(int i=0;i<m;i++)g[i]=tg[i];
        for(int i=0;i<n;i++)fr[n-i-1]=f[i];
        for(int i=0;i<m;i++)gr[m-i-1]=g[i];
        getinv(n-m+1,gr,invgr),times(n,n-m+1,fr,invgr,dr);
        for(int i=n-m;~i;i--)d[n-m-i]=dr[i];
    }
    inline void modulo(int n,int m,int *tf,int *tg,int *r)
    {
        for(int i=0;i<n+m<<1;i++)
            f[i]=g[i]=fr[i]=gr[i]=invgr[i]=dr[i]=tmp[i]=td[i]=0;
        for(int i=0;i<n;i++)f[i]=tf[i];
        for(int i=0;i<m;i++)g[i]=tg[i];
        for(int i=0;i<n;i++)fr[n-i-1]=f[i];
        for(int i=0;i<m;i++)gr[m-i-1]=g[i];
        getinv(n-m+1,gr,invgr),times(n,n-m+1,fr,invgr,dr);
        for(int i=n-m;~i;i--)td[n-m-i]=dr[i];
        times(m,n-m+1,g,td,tmp);
        for(int i=0;i<m-1;i++)
            r[i]=(f[i]+MOD-tmp[i])%MOD;
    }
    int *p[N<<2],length[N<<2],ta[N],tb[N];
    inline void init(int len,int lg){for(int i=0;i<len;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(lg-1));}
    inline void getp(int u,int l,int r,int *a)
    {
        length[u]=r-l+1,p[u]=new int[length[u]+1];
        if(l==r)
            return p[u][0]=(MOD-a[l]),p[u][1]=1,(void)0;
        int mid=l+r>>1;
        getp(u<<1,l,mid,a),getp(u<<1|1,mid+1,r,a);
        int len=1,lg=0;
        while(len<(length[u]+1<<1))len<<=1,lg++;
        init(len,lg);
        for(int i=0;i<=length[u<<1];i++)ta[i]=p[u<<1][i];
        for(int i=length[u<<1]+1;i<len;i++)ta[i]=0;
        for(int i=0;i<=length[u<<1|1];i++)tb[i]=p[u<<1|1][i];
        for(int i=length[u<<1|1]+1;i<len;i++)tb[i]=0;
        NTT(len,ta,1),NTT(len,tb,1);
        for(int i=0;i<len;i++)ta[i]=ta[i]*tb[i]%MOD;
        NTT(len,ta,-1);
        for(int i=0;i<=length[u];i++)p[u][i]=ta[i];
    }
    inline void solve(int u,int l,int r,int *a,int *f,int *ans)
    {
        if(length[u]<=500)
        {
            int m=length[u]-1;
            for(int i=l;i<=r;i++)
                for(int j=m;~j;j--)ans[i]=(ans[i]*a[i]+f[j])%MOD;
            return;
        }
        if(l==r)return ans[l]=*f,(void)0;
        int mid=l+r>>1,md[length[u]+2]={0};
        modulo(length[u],length[u<<1]+1,f,p[u<<1],md);
        solve(u<<1,l,mid,a,md,ans);
        modulo(length[u],length[u<<1|1]+1,f,p[u<<1|1],md);
        solve(u<<1|1,mid+1,r,a,md,ans);
    }
    inline void evaluation(int n,int m,int *a,int *f,int *ans)
    {
        getp(1,1,m,a);
        if(n>m)modulo(n,m+1,f,p[1],f);
        solve(1,1,m,a,f,ans);
    }
    inline void getdev(int n,int *f,int *g)
    {
        for(int i=1;i<n;i++)g[i-1]=i*f[i]%MOD;
        g[n-1]=0;
    }
    inline void getinvdev(int n,int *f,int *g)
    {
        for(int i=1;i<n;i++)
            g[i]=f[i-1]*qpow(i,MOD-2)%MOD;
        g[0]=0;
    }
    inline void getln(int n,int *f,int *g)
    {
        memset(ta,0,sizeof(ta));
        memset(tb,0,sizeof(tb));
        getdev(n,f,ta),getinv(n,f,tb);
        int lg=0,len=1;
        while(len<(n<<1))lg++,len<<=1;
        init(len,lg),NTT(len,ta,1),NTT(len,tb,1);
        for(int i=0;i<len;i++)ta[i]=ta[i]*tb[i]%MOD;
        NTT(len,ta,0),getinvdev(n,ta,g);
    }
    int t[N];
    inline void getexp(int n,int *f,int *g)
    {
        if(n==1)return g[0]=1,(void)0;
        getexp(n+1>>1,f,g),getln(n,g,t);
        int len=1,lg=0;
        while(len<=(n<<1))len<<=1,lg++;
        for(int i=1;i<len;i++)r[i]=(r[i>>1]>>1)|((i&1)<<lg-1);
        memset(ta,0,sizeof(ta));
        for(int i=0;i<n;i++)ta[i]=f[i];
        for(int i=n;i<len;i++)t[i]=ta[i]=0;
        for(int i=0;i<len;i++)ta[i]=(ta[i]-t[i]+MOD)%MOD;
        ta[0]++,NTT(len,g,1),NTT(len,ta,1);
        for(int i=0;i<len;i++)g[i]=g[i]*ta[i]%MOD;
        NTT(len,g,0);
        for(int i=n;i<len;i++)g[i]=0;
        for(int i=0;i<len;i++)ta[i]=0;
    }
    int pwtmp[N];
    inline void getpow(int n,int *f,int *g,int k)
    {
        memset(pwtmp,0,sizeof(pwtmp)),getln(n,f,pwtmp);
        for(int i=0;i<n;i++)pwtmp[i]=pwtmp[i]*k%MOD;
        getexp(n,pwtmp,g);
    }
    int tx1[N],tx2[N],tx3[N];
    inline void getsin(int n,int *f,int *g)
    {
        memset(tx1,0,sizeof(tx1));
        memset(tx2,0,sizeof(tx2));
        memset(tx3,0,sizeof(tx3));
        for(int i=0;i<n;i++)tx1[i]=f[i]*img%MOD;
        getexp(n,tx1,tx2),getinv(n,tx2,tx3);
        for(int i=0;i<n;i++)g[i]=(tx2[i]-tx3[i]+MOD)%MOD*qpow(img<<1,MOD-2)%MOD;
    }
    inline void getcos(int n,int *f,int *g)
    {
        memset(tx1,0,sizeof(tx1));
        memset(tx2,0,sizeof(tx2));
        memset(tx3,0,sizeof(tx3));
        for(int i=0;i<n;i++)tx1[i]=f[i]*img%MOD;
        getexp(n,tx1,tx2),getinv(n,tx2,tx3);
        for(int i=0;i<n;i++)g[i]=(tx2[i]+tx3[i])%MOD*(MOD+1>>1)%MOD;
    }
}
using POLY::MOD;
using POLY::qpow;
int n,a,b,c,fac[N],ifac[N];
int f1[N],f2[N],f3[N],g1[N],g2[N];
void init()
{
    fac[0]=1;
    for(int i=1;i<N;i++)fac[i]=fac[i-1]*i%MOD;
    ifac[N-1]=qpow(fac[N-1],MOD-2);
    for(int i=N-2;~i;i--)ifac[i]=ifac[i+1]*(i+1)%MOD;
}
void solve1()
{
    for(int i=0;i*a<=n;i++)f1[i*a]=1;
    for(int i=0;i*b<=n;i++)f2[i*b]=1;
    for(int i=0;i*c<=n;i++)f3[i*c]=1;
    POLY::times(n+1,n+1,f1,f2,g1);
    POLY::times(n+1,n+1,g1,f3,g2);
    printf("%lld\n",g2[n]);
}
void solve2()
{
    init();
    for(int i=0;i<N;i++)f1[i]=f2[i]=f3[i]=g1[i]=g2[i]=0;
    for(int i=0;i*a<=n;i++)f1[i*a]=ifac[i*a];
    for(int i=0;i*b<=n;i++)f2[i*b]=ifac[i*b];
    for(int i=0;i*c<=n;i++)f3[i*c]=ifac[i*c];
    POLY::times(n+1,n+1,f1,f2,g1);
    POLY::times(n+1,n+1,g1,f3,g2);
    printf("%lld\n",g2[n]*fac[n]%MOD);
}
main()
{
    scanf("%lld%lld%lld%lld",&n,&a,&b,&c);
    solve1(),solve2();
    return 0;
}
posted @ 2025-09-08 16:51  dyc2022  阅读(22)  评论(0)    收藏  举报
/* 设置动态特效 */ /* 设置文章评论功能 */ 返回顶端 levels of contents