容斥

牛客09 F (优化DP状态)

有一个思路比较顺的DP:按照num一个个考虑,记录当前每个位置是否已填,以及B数组剩余的状态;然后枚举新增的位置,判断合法性转移。容易发现B数组剩余的状态数不多,手模几个情况卡不到100(实际上最多72);瓶颈在于转移,如果直接枚举新增的集合,总复杂度是\(O(72m3^n)\),可以用每次新增一个的方法优化枚举子集,但这样得记录新增的个数以及上一个新增的位置,总复杂度是\(O(72mn^22^n)\),都难以通过。
发现按照上面的思路已经难以优化了,即同时记录72和\(2^n\)过于暴力了!考虑优化这个\(2^n\):我们要记这么多,为的就是满足每个位置恰好只填一个数这个比较严格的限制;能否将其放宽?考虑容斥!由于把num全部考虑完后,总共填了n次,所以只需考虑每个位置都填过或者每个位置被填不超过一次这两个等价表达中的一个,而容斥的话,肯定是对前者容斥更容易,即枚举至少没有被填过的位置集合,问题就变成在一个位置子集内填数,那DP的时候就只要记录72个状态,用组合数算算转移即可!总复杂度是\(O(72mn2^n)\)

点击查看代码
#include "bits/stdc++.h"
using namespace std;
const int P=1e9+7;
void inc(int& x,int y){
    x+=y;
    if(x>=P) x-=P;
}
int prd(int x,int y){
    return 1ll*x*y%P;
}
const int N=35,M=105;
int c[N][N];
int C(int n,int m){
    if(m<0 || m>n) return 0;
    return c[n][m];
}
int n,m,a[N],cnt[N],tot;
bool A[N][N],B[N][N];
int idt,tt;
struct edge{
    int v,d;
};
vector<edge>to[M];
map< vector<int>,int>mp;
void dfs(vector<int>res,int pos){
    mp[res]=++idt;
    int id=idt;
    bool pd=0;
    for(auto u:res) if(u) pd=1;
    if(!pd) tt=id;
    for(int j=pos+1;j<tot;j++) if(res[j]){
        res[j]--;
        dfs(res,j);
        res[j]++;
    }
    if(res[pos]){
        res[pos]--;
        dfs(res,pos);
        res[pos]++;
    }
}
void add(vector<int>res,int pos){
    for(int j=pos+1;j<tot;j++) if(res[j]){
        res[j]--;
        add(res,j);
        res[j]++;
    }
    if(res[pos]){
        res[pos]--;
        add(res,pos);
        res[pos]++;
    }
    //for(auto u:res) cout<<u<<" ";
    int id=mp[res];
    //cout<<"pos="<<pos<<" id="<<id<<endl;
    for(int i=0;i<tot;i++) if(res[i]){
        res[i]--;
        to[id].push_back((edge){mp[res],i+1});
        //cout<<id<<" "<<mp[res]<<" "<<i+1<<endl;
        res[i]++;
    }
    to[id].push_back((edge){id,0});
}
int pw[N],c1[1<<16];
void solve() {
    memset(A,0,sizeof(A));
    memset(B,0,sizeof(B));
    for(int i=0;i<M;i++) to[i].clear();
    int x,y;
    cin>>n>>m>>x>>y;
    int lst=0,c0=0;
    tot=0;
    for(int u,i=1;i<=m;i++){
        scanf("%d",&u);
        if(!u){
            c0++;
            continue;
        }
        if(u==lst) cnt[tot]++;
        else a[++tot]=u,cnt[tot]=1,lst=u;
    }
    //cout<<"tot="<<tot<<endl;
    //for(int i=1;i<=tot;i++) cout<<a[i]<<" "<<cnt[i]<<endl;

    while(x--){
        int u,v;
        cin>>u>>v;
        A[u][v]=1;
    }
    while(y--){
        int u,v;
        cin>>u>>v;
        B[u][v]=1;
    }

    tt=idt=0; mp.clear();
    vector<int>rest;
    for(int i=1;i<=tot;i++) rest.push_back(cnt[i]);
    dfs(rest,0);
    add(rest,0);
    int ans=0;
    for(int s2=1;s2<pw[n];s2++){
        int f[N][M]={0};
        f[0][1]=1;
        for(int i=1;i<=m;i++){ //num
            int vld=0;
            for(int j=1;j<=n;j++) if(s2&pw[j-1] && !A[j][i]) vld++;
            //cout<<"i="<<i<<" vld="<<vld<<endl;
            for(int s1=1;s1<=idt;s1++) if(f[i-1][s1]){
                //cout<<"s1="<<s1<<endl;
                for(auto u:to[s1]) if(!B[i][a[u.d]]){
                    //cout<<"d="<<u.d<<" v="<<u.v<<" a="<<a[u.d]<<endl;
                    inc(f[i][u.v],prd(f[i-1][s1],C(vld,a[u.d])) );
                }
            }
        }
        int op=1;
        for(int i=0;i<n;i++) if(!(s2&pw[i])) op=P-op;
        inc(ans,prd(op,f[m][tt]));
    }
    cout<<ans<<endl;
}

signed main() {
//	freopen("in","r",stdin);
//	freopen("out","w",stdout);

    for(int i=0;i<=16;i++){
        pw[i]=(1<<i);
        c[i][0]=1;
        if(!i) continue;
        for(int j=1;j<=i;j++) c[i][j]=c[i-1][j-1],inc(c[i][j],c[i-1][j]);
    }
    for(int i=0;i<pw[16];i++){
        for(int j=0;j<16;j++){
            if(pw[j]&i) c1[i]++;
        }
    }
	int T;cin >> T;while( T-- ) solve();
//	solve();
}

posted @ 2023-08-16 18:06  sz[sz]  阅读(7)  评论(0编辑  收藏  举报