【LOJ2541】猎人杀(PKUWC2018)-容斥+级数+分治NTT

测试地址:猎人杀
做法:本题需要用到容斥+级数+分治NTT。
要求1号最后一个被射杀,其实就是要求所有人都不能在1号后被射杀。这种要求全部条件满足求方案数/概率的情况,就要考虑容斥,即枚举一个集合S,计算强制这S个人在1号后被射杀的概率p(S),那么答案就等于:
ans=S(1)|S|p(S)
可是由于游戏的每一步中,概率的分母都不同,p(S)很难计算,怎么办呢?我们需要对游戏做出一些转化:一个人被射杀后,他仍然参与概率的计算,但如果射中了已经被射杀的人,就再射一次,显然这和原来的游戏是等价的。这样的话,令sum(S)=iSwi,W=i=1nwi,我们有:
p(S)=i=0(1w1+sum(S)W)iw1W
w1W提出来后,剩下的和式是一个无穷级数,因为0<1w1+sum(S)W<1,所以这个级数是收敛的,那么它就等于前缀和数列的极限。我们有公式:
i=0xi=11x
所以有:
p(S)=w1WWw1+sum(S)=w1w1+sum(S)
于是有:
ans=w1S(1)|S|w1+sum(S)
虽然我们极大地简化了所求的式子,但是这个还是不太好求。这时我们注意到一个条件:i=1nwi105,这启发我们分开计算每种分母的贡献。于是我们构造一个生成函数,其中xi项的系数就表示分母为i的数对答案贡献的分子,我们怎么算出这个生成函数呢?注意到这就等于xw1i=2n(x0xwi),于是分治NTT求出后面的部分即可。这里的分治NTT就是单纯的分治+NTT,而不是CDQ分治+NTT。于是我们就解决了这一题,时间复杂度为O(WlogWlogn)
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
const ll g=3;
int n,sum,rev[200010],cnt=0,siz[30];
ll w[200010],A[30][200010];

ll power(ll a,ll b)
{
    ll s=1,ss=a;
    b=(b+mod-1)%(mod-1);
    while(b)
    {
        if (b&1) s=s*ss%mod;
        ss=ss*ss%mod;b>>=1;
    }
    return s;
}

ll NTT(ll *a,int n,int type)
{
    for(int i=0;i<n;i++)
        if (i<rev[i]) swap(a[i],a[rev[i]]);
    for(int mid=1;mid<n;mid<<=1)
    {
        ll W=power(g,type*(mod-1)/(mid<<1));
        for(int l=0,G=(mid<<1);l<n;l+=G)
        {
            ll w=1;
            for(int k=0;k<mid;k++,w=w*W%mod)
            {
                ll x=a[l+k],y=w*a[l+mid+k]%mod;
                a[l+k]=(x+y)%mod;
                a[l+mid+k]=(x-y+mod)%mod;
            }
        }
    }
    if (type==-1)
    {
        ll inv=power(n,mod-2);
        for(int i=0;i<n;i++)
            a[i]=a[i]*inv%mod;
    }
}

void solve(int l,int r)
{
    if (l==r)
    {
        cnt++;
        A[cnt][0]=1,A[cnt][w[l]]=mod-1;
        siz[cnt]=w[l];
        for(int i=1;i<w[l];i++)
            A[cnt][i]=0;
        return;
    }

    int mid=(l+r)>>1;
    solve(l,mid);
    solve(mid+1,r);

    int bit=0,x=1,a=cnt-1,b=cnt,tot=siz[a]+siz[b];
    while(x<=tot) bit++,x<<=1;
    rev[0]=0;
    for(int i=1;i<x;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    for(int i=siz[a]+1;i<x;i++)
        A[a][i]=0;
    for(int i=siz[b]+1;i<x;i++)
        A[b][i]=0;
    NTT(A[a],x,1),NTT(A[b],x,1);
    for(int i=0;i<x;i++)
        A[a][i]=A[a][i]*A[b][i]%mod;
    NTT(A[a],x,-1);

    cnt--;
    siz[cnt]=tot;
}

int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    {
        scanf("%lld",&w[i]);
        sum+=w[i];
    }

    if (n==1) printf("1");
    else
    {
        solve(2,n);
        ll ans=0;
        for(int i=0;i<=sum;i++)
            ans=(ans+A[1][i]*power(w[1]+i,mod-2))%mod;
        printf("%lld",ans*w[1]%mod); 
    }

    return 0; 
}
posted @ 2018-06-16 17:31  Maxwei_wzj  阅读(111)  评论(0编辑  收藏  举报