[LOJ2541] [PKUWC2018] 猎人杀

题目链接

LOJ:https://loj.ac/problem/2541

Solution

很巧妙的思路。

注意到运行的过程中概率的分母在不停的变化,这样会让我们很不好算,我们考虑这样转化:假设所有人都活着,然后随机选一个人,如果此人已死那就重新选一次。

假设当前活着的人集合为\(T\),那么射中第\(i\)个人的概率就是:

\[\sum_{i=0}^{\infty}\left(\frac{s_{all}-s_T}{s_{all}}\right)^i\frac{w_i}{s_{all}}=\frac{w_i}{s_T} \]

其中\(s_p\)表示\(p\)集合的\(w\)总和,可以发现这样选的概率和原来是一样的。

我们考虑容斥,设\(f(T)\)表示至少\(T\)集合的人比\(1\)号后死,用一个很简单的容斥可以得到:

\[ans=\sum_{T}(-1)^{|T|}f(T) \]

那么大力算可以得到\(f\)

\[\begin{align}f(T)&=\sum_{i=0}^{\infty}\left(\frac{s_{all}-s_T-w_1}{s_{all}}\right)^i\cdot \frac{w_1}{s_{all}}\\&=\frac{w_1}{w_1+s_T}\end{align} \]

答案就是:

\[ans=\sum_T(-1)^{|T|}\frac{w_1}{w_1+s_T} \]

注意到\(s\)至多只有\(1e5\),我们可以背包算出每个\(s_T\)出现了多少次,背包的时候顺便把容斥系数带上。

这样做是\(O(ns)\)的,显然\(T\)掉了。

但是我们可以用生成函数优化这个东西,直接就是:

\[\prod_{i=2}^{n}(1-x^{w_i}) \]

然后分治\(FFT\)优化就好了,复杂度\(O(n\log ^2 n)\)

#include<bits/stdc++.h>
using namespace std;

void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}

void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

#define lf double
#define ll long long 

#define pii pair<int,int >
#define vec vector<int >

#define pb push_back
#define mp make_pair
#define fr first
#define sc second

#define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++) 

const int maxn = 4e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 998244353;

int w[maxn],pos[maxn],N,bit,f[maxn],a[maxn],s[maxn],n,mxn;

int add(int x,int y) {return x+y>=mod?x+y-mod:x+y;}
int del(int x,int y) {return x-y<0?x-y+mod:x-y;}
int mul(int x,int y) {return 1ll*x*y-1ll*x*y/mod*mod;}

int qpow(int a,int x) {
    int res=1;
    for(;x;x>>=1,a=mul(a,a)) if(x&1) res=mul(res,a);
    return res;
}

void prepare(int t) {
    for(N=1,bit=0;N<=t;N<<=1,bit++);mxn=N;w[0]=1,w[1]=qpow(3,(mod-1)/mxn);
    for(int i=2;i<=N;i++) w[i]=mul(w[i-1],w[1]);
}

void ntt_get(int t) {
    for(N=1,bit=0;N<=t;N<<=1,bit++);
    for(int i=1;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));
}

void ntt(int *r,int op) {
    for(int i=1;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
    for(int i=1,d=mxn>>1;i<N;i<<=1,d>>=1)
        for(int j=0;j<N;j+=i<<1)
            for(int k=0;k<i;k++) {
                int x=r[j+k],y=mul(r[i+j+k],w[k*d]);
                r[j+k]=add(x,y),r[i+j+k]=del(x,y);
            }
    if(op==-1) {
        reverse(r+1,r+N);int d=qpow(N,mod-2);
        for(int i=0;i<N;i++) r[i]=mul(r[i],d);
    }
}

int get(int lt,int rt) {
    int l=lt,r=rt,mid,ans=lt;
    while(l<=r) {
        mid=(l+r)>>1;
        if(s[rt]-s[mid]>=s[mid]-s[lt-1]) l=mid+1,ans=mid;
        else r=mid-1;
    }return ans;
}

void solve(int l,int r,int *t) {
    if(l>r) return ;
    if(l==r) {t[0]=1,t[a[l]]=mod-1;return ;}
    int d=1<<((int)ceil(log2(s[r]-s[l-1]))+1);
    int *sl=new int [d+10],*sr=new int [d+10],mid=get(l,r);
    for(int i=0;i<=d+5;i++) sl[i]=sr[i]=0;
    solve(l,mid,sl),solve(mid+1,r,sr);
    ntt_get(d>>1);ntt(sl,1),ntt(sr,1);
    for(int i=0;i<N;i++) t[i]=mul(sl[i],sr[i]);
    ntt(t,-1);delete sl;delete sr;
}

int main() {
    read(n);for(int i=1;i<=n;i++) read(a[i]),s[i]=s[i-1]+a[i];
    prepare(s[n]<<1);solve(2,n,f);int ans=0;
    for(int i=0;i<=s[n];i++) ans=add(ans,mul(qpow(a[1]+i,mod-2),f[i]));
    write(mul(ans,a[1]));
    return 0;
}
posted @ 2019-05-21 11:55  Hyscere  阅读(166)  评论(0编辑  收藏  举报