Luogu4859 二项式反演

今天学了一个叫二项式反演的有趣东西.

其实它的核心式子就两个

\[g_i=\sum_{j=i}^n\binom{j}{i}f[j]\\ f_i=\sum_{j=i}^n(-1)^{j-i}\binom{j}{i}g[j] \]

证明是用容斥证的.


现在我们看这道题.
题目链接

我们知道答案就是\(a>b\)的对数为\(\frac{n+k}{2}\)的方案数.
\(x=\frac{n+k}{2}\)
考虑普通\(dp\).
\(f[i][j]\)表示前\(i\)个数,已经有\(j\)\(a>b\)的方案.
那么\(f[i][j]=f[i-1][j]+(lb[i]-j)*f[i-1][j-1]\)
其中\(lb[i]\)=\(lowerbound(b,a[i])\)

我们令\(g[i]\)表示\(a>b\)的对数\(\geq x\)的方案数.
那么\(g[i]=f[n][i]*(n-i)!\)
就是钦点了\(i\)\(a>b\)匹配,然后剩下的随意匹配的方案数.
\(a[i]\)表示刚好有\(i\)\(a>b\)的方案数.
考虑对于\(\forall j\in [i,n],a[j]\)\(g[i]\)中计算了\(C_j^i\)次.
这个证明..显然
那么我们列出方程\(g_i=\sum_{j=i}^n\binom{j}{i}a_j\)
二项式反演得\(a_i=\sum_{j=i}^n(-1)^{j-i}\binom{j}{i}g_j\)

于是只要先\(O(n^2)dp\)一下,然后再\(O(n)\)计算一下即可.

代码如下

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#define N (2010)
#define P (1000000009) 
#define inf (0x7f7f7f7f)
#define rg register int
#define Label puts("NAIVE")
#define spa print(' ')
#define ent print('\n')
#define rand() (((rand())<<(15))^(rand()))
typedef long double ld;
typedef long long LL;
typedef unsigned long long ull;
using namespace std;
inline char read(){
    static const int IN_LEN=1000000;
    static char buf[IN_LEN],*s,*t;
    return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x){
    static bool iosig;
    static char c;
    for(iosig=false,c=read();!isdigit(c);c=read()){
        if(c=='-')iosig=true;
        if(c==-1)return;
    }
    for(x=0;isdigit(c);c=read())x=((x+(x<<2))<<1)+(c^'0');
    if(iosig)x=-x;
}
inline char readchar(){
    static char c;
    for(c=read();!isalpha(c);c=read())
    if(c==-1)return 0;
    return c;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN],*ooh=obuf;
inline void print(char c) {
    if(ooh==obuf+OUT_LEN)fwrite(obuf,1,OUT_LEN,stdout),ooh=obuf;
    *ooh++=c;
}
template<class T>
inline void print(T x){
    static int buf[30],cnt;
    if(x==0)print('0');
    else{
        if(x<0)print('-'),x=-x;
        for(cnt=0;x;x/=10)buf[++cnt]=x%10+48;
        while(cnt)print((char)buf[cnt--]);
    }
}
inline void flush(){fwrite(obuf,1,ooh-obuf,stdout);}
int n,m,a[N],b[N],lb[N],k;
LL f[N][N],g[N],ans,jc[N],inv[N];
LL ksm(LL a,int p){
    LL res=1;
    while(p){
        if(p&1)res=(res*a)%P;
        a=(a*a)%P,p>>=1;
    }
    return res;
}
LL C(int n,int m){return jc[n]*inv[n-m]%P*inv[m]%P;}
int main(){
    read(n),read(m),f[0][0]=1,jc[0]=inv[0]=1;
    if((n+m)&1){puts("0"),exit(0);}else k=(n+m)/2;
    for(int i=1;i<=n;i++)jc[i]=(jc[i-1]*1ll*i)%P,inv[i]=ksm(jc[i],P-2);
    for(int i=1;i<=n;i++)read(a[i]);
    for(int i=1;i<=n;i++)read(b[i]);
    sort(a+1,a+n+1),sort(b+1,b+n+1);
    for(int i=1;i<=n;i++)lb[i]=lower_bound(b+1,b+n+1,a[i])-b;
    for(int i=1;i<=n;i++){
        f[i][0]=f[i-1][0];
        for(int j=1;j<=i;j++)
        f[i][j]=(f[i-1][j-1]*(LL)max(0,lb[i]-j)%P+f[i-1][j])%P;
    }
    for(int i=1;i<=n;i++)g[i]=f[n][i]*jc[n-i]%P;
    for(int i=k;i<=n;i++){
        LL tp=((i-k)&1)?-1:1;
        (ans+=(1ll*tp*C(i,k)%P*g[i])%P)%=P;
    }
    printf("%lld\n",(ans+P)%P);
}
posted @ 2018-12-14 15:22  Romeolong  阅读(136)  评论(0编辑  收藏  举报