【BZOJ5287】[HNOI2018]毒瘤(动态规划,容斥)

【BZOJ5287】[HNOI2018]毒瘤(动态规划,容斥)

题面

BZOJ
洛谷

题解

考场上想到的暴力做法是容斥:
因为\(m-n\le 10\),所以最多会多出来\(11\)条非树边。
如果就是一棵树的话,显然答案就是独立集的个数。
非树边\(2^{11}\)枚举,强制非树边的两端同时备选导致不合法,容斥计算答案即可。
这样子的复杂度是\(O(2^{11}n)\),估算出来是\(2s\),然而在\(HNOI\)考场跑要\(20s\)(大雾

考虑如何优化这个东西。
我们\(2^{11}\)枚举出来之后,显然是强制令枚举的非树边的两端都被选入进集合。但是我们并不需要每次重新做一遍\(dp\),显然会出现大量的重复计算内容。
把枚举的点的虚树给构建出来,显然会影响到的部分只有虚树上的点和链。
对于每个虚树上的点,考虑修改后对于其虚树上父亲的影响。
因为\(dp\)状态是\(f[i][0/1]\),所以可以把关键点的状态设为\(x,y\),到虚树上父亲的链的转移全部用\(x,y\)的形式转移,这样子到其父亲时就可以合并一堆\(x,y\)的状态,当确定所有\(x,y\)后就能确定所有虚树上的关键点的\(dp\)值。
这样子单次容斥的复杂度就变成了虚树点数,这个东西很小。
这是一个很类似于动态\(dp\)的思路。

代码有点丑

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
using namespace std;
#define ll long long
#define MAX 100100
#define MOD 998244353
#define pb push_back
inline int read()
{
	int x=0;bool t=false;char ch=getchar();
	while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
	if(ch=='-')t=true,ch=getchar();
	while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
	return t?-x:x;
}
int fpow(int a,int b)
{
	int s=1;
	while(b){if(b&1)s=1ll*s*a%MOD;a=1ll*a*a%MOD;b>>=1;}
	return s;
}
int dsu[MAX];
int getf(int x){return x==dsu[x]?x:dsu[x]=getf(dsu[x]);}
struct Line{int v,next;}e[MAX<<1];
int h[MAX],cnt=1;
inline void Add(int u,int v){e[cnt]=(Line){v,h[u]};h[u]=cnt++;}
int n,m,ans,f[MAX][2],zr[MAX][2],g[MAX][2];
int fa[MAX],dfn[MAX],tim,size[MAX],hson[MAX],top[MAX],dep[MAX],low[MAX];
void dfs1(int u,int ff)
{
	f[u][0]=f[u][1]=1;fa[u]=ff;dep[u]=dep[ff]+1;size[u]=1;
	for(int i=h[u];i;i=e[i].next)
	{
		int v=e[i].v;if(v==ff)continue;
		dfs1(v,u);size[u]+=size[v];
		if(size[v]>size[hson[u]])hson[u]=v;
		if((f[v][0]+f[v][1])%MOD)f[u][0]=1ll*f[u][0]*(f[v][0]+f[v][1])%MOD;else zr[u][0]+=1;
		if(f[v][0])f[u][1]=1ll*f[u][1]*f[v][0]%MOD;else zr[u][1]+=1;
	}
}
void dfs2(int u,int tp)
{
	top[u]=tp;dfn[u]=++tim;
	if(hson[u])dfs2(hson[u],tp);
	for(int i=h[u];i;i=e[i].next)
		if(e[i].v!=fa[u]&&e[i].v!=hson[u])
			dfs2(e[i].v,e[i].v);
	low[u]=tim;
}
int LCA(int u,int v)
{
	while(top[u]^top[v])dep[top[u]]<dep[top[v]]?v=fa[top[v]]:u=fa[top[u]];
	return dep[u]<dep[v]?u:v;
}
bool cmp(int a,int b){return dfn[a]<dfn[b];}
int S[MAX],Top,snt;bool spn[MAX];
struct data{int x,y;}nt[50];
data operator*(data a,int b){return (data){1ll*a.x*b%MOD,1ll*a.y*b%MOD};}
data operator+(data a,data b){return (data){(a.x+b.x)%MOD,(a.y+b.y)%MOD};}
vector<int> fr[MAX];
vector<data> F0[MAX],F1[MAX];
int Q[MAX],tot;
int Div(int i,int p,int j)
{
	if(j)return zr[i][p]?0:1ll*f[i][p]*fpow(j,MOD-2)%MOD;
	else return zr[i][p]==1?f[i][p]:0;
}
void Calc(int x,int y)
{
	data f0=(data){1,0},f1=(data){0,1},ff0,ff1;
	int p=x;
	for(int i=fa[x],j=x;i!=y;p=j=i,i=fa[i])
	{
		int F0=Div(i,0,(f[j][0]+f[j][1])%MOD),F1=Div(i,1,f[j][0]);
		ff0=(f0+f1)*F0;ff1=f0*F1;
		f0=ff0;f1=ff1;
	}
	fr[y].pb(x);F0[y].pb(f0);F1[y].pb(f1);
	int a=(f[p][0]+f[p][1])%MOD,b=f[p][0];
	if(a)f[y][0]=1ll*f[y][0]*fpow(a,MOD-2)%MOD;else zr[y][0]-=1;
	if(b)f[y][1]=1ll*f[y][1]*fpow(b,MOD-2)%MOD;else zr[y][1]-=1;
	
}
bool Vis[MAX];
int DP()
{
	for(int i=Top;i;--i)g[S[i]][0]=zr[S[i]][0]?0:f[S[i]][0],g[S[i]][1]=zr[S[i]][1]?0:f[S[i]][1];
	for(int i=Top;i;--i)
		if(Vis[S[i]])g[S[i]][0]=0;
	for(int i=Top;i;--i)
		for(int j=0,l=fr[S[i]].size();j<l;++j)
		{
			int u=S[i],v=fr[u][j];
			data f0=F0[u][j],f1=F1[u][j];
			int ff0=(1ll*f0.x*g[v][0]+1ll*f0.y*g[v][1])%MOD;
			int ff1=(1ll*f1.x*g[v][0]+1ll*f1.y*g[v][1])%MOD;
			g[u][0]=1ll*g[u][0]*(ff0+ff1)%MOD;
			g[u][1]=1ll*g[u][1]*ff0%MOD;
		}
	return (g[1][0]+g[1][1])%MOD;
}
int main()
{
	n=read();m=read();
	for(int i=1;i<=n;++i)dsu[i]=i;
	for(int i=1;i<=m;++i)
	{
		int u=read(),v=read();
		if(getf(u)==getf(v))S[++Top]=u,S[++Top]=v,nt[snt++]=(data){u,v};
		else Add(u,v),Add(v,u),dsu[getf(u)]=getf(v);
	}
	dfs1(1,0);dfs2(1,1);S[++Top]=1;
	sort(&S[1],&S[Top+1],cmp);
	for(int i=Top;i>1;--i)S[++Top]=LCA(S[i],S[i-1]);
	sort(&S[1],&S[Top+1],cmp);Top=unique(&S[1],&S[Top+1])-S-1;
	for(int i=1;i<=Top;++i)spn[S[i]]=true;
	Q[tot=1]=S[1];
	for(int i=2;i<=Top;++i)
	{
		while(!(dfn[Q[tot]]<=dfn[S[i]]&&dfn[S[i]]<=low[Q[tot]]))--tot;
		Calc(S[i],Q[tot]);Q[++tot]=S[i];
	}
	for(int i=0;i<1<<snt;++i)
	{
		int d=1;
		for(int j=0;j<snt;++j)
			if(i&(1<<j))
				Vis[nt[j].x]=Vis[nt[j].y]=true,d^=1;
		int ret=DP();
		if(d)ans=(ans+ret)%MOD;
		else ans=(ans+MOD-ret)%MOD;
		for(int j=0;j<snt;++j)Vis[nt[j].x]=Vis[nt[j].y]=false;
	}
	printf("%d\n",ans);
	return 0;
}
posted @ 2019-02-13 11:14  小蒟蒻yyb  阅读(371)  评论(0编辑  收藏  举报