【BZOJ3129】方程(SDOI2013)-容斥原理+扩展Lucas定理

测试地址:方程
做法:本题需要用到容斥原理+扩展Lucas定理。
首先,如果没有任何限制,那么非负整数解的数量就是Cm+n1n1,这个可以用隔板法求出,那么要求正整数解的话,其实只要转化成求(xi+1)=m的非负整数解数量即可,显然上面的方程可以转化为xi=mn来求。
现在我们考虑限制,对于第二种限制,我们可以把xi转化为xi+Ai1,就可以转化成求另一个方程的正整数解数。然而第一种限制就不好办了,这时候看到n18,想到容斥,即我们如果要求满足所有条件的方案数,我们可以用没有任何限制的方案数,减去强制某一个条件满足不了的方案数,加上强制某两个条件满足不了的方案数……显然如果一个限制满足不了,就会变成xi>a的形式,这个就能很容易地化成第二种限制求解了。
那么最后一个问题,也是这道题最困难的一个点,就是求大组合数取模。我们知道如果p为质数,可以用Lucas定理求出,但这里p可能是合数,我们只能将其质因数分解为p1t1×p2t2×...×pktk这样的形式,然后分别求组合数对p1t1,p2t2,...,pktk取模的结果,然后用中国剩余定理或合并模线性方程的方法将最后的结果求出。
现在问题变成求组合数对pt取模的结果,其中p为质数。我们可以根据公式Cnm=n!m!(nm)!将问题转化为求n!pt取模的结果。我们把1×2×...×n拆成长度为p的若干段,然后把所有p的倍数提出来,显然这些东西乘起来等于pnp×np!,对于np!递归求解,对于剩下的部分,我们可以把连续的pt1段看成一个周期,因为pt+11(modpt),那么我们可以先直接快速幂求出完整的npt个周期的乘积,然后剩下的部分最多长为pt,直接计算即可。这一个部分的时间复杂度应该是O(lognpt),如果预处理出缺项前缀积(就是把p的倍数舍掉的阶乘)常数会小很多。
这里要使用欧拉定理:aφ(n)1(modn)来求出逆元,显然由欧拉函数的定义有φ(pt)=(p1)pt1。特别地,如果结果中包含因子p,那么我们无法直接求得逆元,所以我们在计算时要独立计算因子p的幂数,仅对不含因子p的部分求逆元即可。
BZOJ的题面中缺了一个比较重要的信息,数据范围中的p是固定的一些数,这些数中最大的可分解出的pt差不多在10000左右的水准,所以上述的O(2n1lognpt)的算法可以通过此题。
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int T,n1,n2,totcnt=0;
ll n,m,a[20],p,fac[10],cnt[10],F[20010];

ll power(ll a,ll b,ll p)
{
    ll s=1,ss=a;
    while(b)
    {
        if (b&1) s=s*ss%p;
        ss=ss*ss%p,b>>=1;
    }
    return s;
}

ll exgcd(ll a,ll b)
{
    ll x0=1,y0=0,x1=0,y1=1;
    while(b)
    {
        ll q=a/b,tmp;
        tmp=x0,x0=x1,x1=tmp-q*x1;
        tmp=y0,y0=y1,y1=tmp-q*y1;
        tmp=a,a=b,b=tmp%b;
    }
    return x0;
}

ll calc(ll n,ll p,ll t,ll P,ll &s)
{
    s=0;
    if (n<=0) return 1;
    ll x=1;
    x=power(F[P],n/P,P);
    x=x*F[n%P]%P;
    x=x*calc(n/p,p,t,P,s)%P;
    s+=n/p;
    return x;
}

void init(ll p,ll P)
{
    F[0]=1;
    for(ll i=1;i<=P;i++)
    {
        if (i%p!=0) F[i]=F[i-1]*i%P;
        else F[i]=F[i-1];
    }
}

ll exlucas(ll n,ll m)
{
    if (m>n) return 0;
    ll lastr,lasta;
    for(int i=1;i<=totcnt;i++)
    {
        ll x,sumt=0,t;
        ll P=power(fac[i],cnt[i],p+1);
        init(fac[i],P);
        x=calc(n,fac[i],cnt[i],P,t);
        sumt+=t;
        x=x*power(calc(m,fac[i],cnt[i],P,t),P/fac[i]*(fac[i]-1)-1,P)%P;
        sumt-=t;
        x=x*power(calc(n-m,fac[i],cnt[i],P,t),P/fac[i]*(fac[i]-1)-1,P)%P;
        sumt-=t;
        x=x*power(fac[i],sumt,P)%P;

        if (i>1)
        {
            ll x0;
            x0=exgcd(lasta,P);
            x0=(x0*(x-lastr)%P+P)%P;
            lastr=(lastr+lasta*x0)%(lasta*P);
            lasta=lasta*P;
        }
        else lastr=x,lasta=P;
    }
    return lastr;
}

int main()
{
    scanf("%d%lld",&T,&p);
    ll x=p;
    for(ll i=2;i<=(ll)sqrt(p)+1;i++)
        if (x%i==0)
        {
            fac[++totcnt]=i;
            cnt[totcnt]=0;
            while(x%i==0)
            {
                cnt[totcnt]++;
                x/=i;
            }
        }
    if (x!=1) fac[++totcnt]=x,cnt[totcnt]=1;

    while(T--)
    {
        scanf("%lld%d%d%lld",&n,&n1,&n2,&m);
        for(int i=1;i<=n1+n2;i++)
            scanf("%lld",&a[i]);
        for(int i=1;i<=n2;i++)
            m-=a[n1+i]-1;
        m-=n;

        ll ans=0;
        for(int i=0;i<(1<<n1);i++)
        {
            int tot=0;
            ll x=m;
            for(int j=0;j<n1;j++)
                if (i&(1<<j))
                {
                    tot++;
                    x-=a[j+1];
                }
            if (tot%2) ans=(ans-exlucas(x+n-1,n-1)+p)%p;
            else ans=(ans+exlucas(x+n-1,n-1))%p;
        }
        printf("%lld\n",ans);
    }

    return 0;
}
posted @ 2018-04-30 21:10  Maxwei_wzj  阅读(168)  评论(0编辑  收藏  举报