[思路题][LOJ2290][THUWC2017]随机二分图:状压DP+期望DP

分析

考虑状压DP,令\(f[sta]\)表示已匹配状态是\(sta\)\(0\)代表已匹配)时完美匹配的期望数量,显然\(f[0]=1\)

一条边出现了不代表它一定在完美匹配内,这也导致很难去直接利用题目中的边组来解决问题。

对于第二类边组,如果把两条边分开考虑(可以理解为把一个第二类的边组看成两个第一类的边组)。如果只有一条边出现在了完美匹配中,此时的贡献是\(50\%\),显然是正确的。如果两条边都出现在了完美匹配中,此时的贡献是\(50\% \times 50\% = 25\%\),但是根据第二类边组的定义,两条边都出现在完美匹配中的贡献应该也是\(50\%\)。所以我们可以再添加一个只包含一条边的边组,这里面的边比较特殊,其连接了这个第二类边组的四个结点,出现概率为\(25\%\),来补充不足的贡献。

第三类边组的处理方法类似,添加一个只包含一条边的边组,边连接了四个结点,出现概率为\(-25\%\),来消去多余的贡献。

为了减小时间复杂度,每次转移时要确保能把最高位的\(1\)异或掉。

代码

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cctype>
#include <algorithm>
#include <map>
#define rin(i,a,b) for(int i=(a);i<=(b);i++)
#define rec(i,a,b) for(int i=(a);i>=(b);i--)
#define trav(i,a) for(int i=head[(a)];i;i=e[i].nxt)
typedef long long LL;
using std::cin;
using std::cout;
using std::endl;

inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
	return x*f;
}

const LL MOD=1e9+7,INV2=5e8+4,INV4=2.5e8+2;
int n,m,cnt,a[505];
LL p[505];
std::map<int,LL> mp;

LL dfs(int sta){
	if(!sta) return 1;
	if(mp.find(sta)!=mp.end()) return mp[sta];
	LL ret=0;
	rin(i,1,cnt){
		if((sta|a[i])==sta&&(a[i]<<1)>sta)
			ret=(ret+dfs(sta^a[i])*p[i])%MOD;
	}
	return mp[sta]=ret;
}

int main(){
	n=read(),m=read();
	rin(i,1,m){
		int typ=read();
		if(typ==0){
			int x=read(),y=read();
			a[++cnt]=((1<<(x-1))|(1<<(y+n-1)));
			p[cnt]=INV2;
		}
		else if(typ==1){
			int x1=read(),y1=read(),x2=read(),y2=read();
			int temp1=((1<<(x1-1))|(1<<(y1+n-1))),temp2=((1<<(x2-1))|(1<<(y2+n-1)));
			a[++cnt]=temp1,p[cnt]=INV2;
			a[++cnt]=temp2,p[cnt]=INV2;
			if(!(temp1&temp2)) a[++cnt]=(temp1|temp2),p[cnt]=INV4;
		}
		else{
			int x1=read(),y1=read(),x2=read(),y2=read();
			int temp1=((1<<(x1-1))|(1<<(y1+n-1))),temp2=((1<<(x2-1))|(1<<(y2+n-1)));
			a[++cnt]=temp1,p[cnt]=INV2;
			a[++cnt]=temp2,p[cnt]=INV2;
			if(!(temp1&temp2)) a[++cnt]=(temp1|temp2),p[cnt]=MOD-INV4;
		}
	}
	mp.clear();
	printf("%lld\n",(1<<n)*dfs((1<<(n<<1))-1)%MOD);
	return 0;
}

posted on 2018-12-27 08:59  ErkkiErkko  阅读(163)  评论(0编辑  收藏  举报