「HNOI2018」毒瘤

「HNOI2018」毒瘤

解题思路

先考虑只有一棵树的情况,经典独立集计数。

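\[dp[u][0]=\prod_{v\in \text{son}(u)} (dp[v][0]+dp[v][1]) \\ dp[u][1]=\prod_{v\in \text{son}(u)} dp[v][0] \]

其中 \(dp[u][0]\) / \(dp[u][1]\) 表示 \(u\) 不选 / 选时 \(u\) 子树内的独立集个数,答案为 \(dp[root][0]+dp[root][1]\)。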

然后任取一棵生成树,把所有非树边的端点(以及根)建成一棵虚树。枚举非树边状态时,虚树以外节点的 \(\text{dp}\) 值不会改变,因此可以预处理出虚树上每个节点对它虚树父亲贡献的系数(一个 \(2\times 2\) 的转移,把被压缩的链上的节点都乘进去)。

然后枚举所有非树边的合法状态(每条非树边的两个端点有三种合法情况:都不选、只选左端点、只选右端点),再在虚树上计算一遍贡献。令 \(k = m-n+1\) 为非树边条数,虚树大小为 \(\mathcal O(k)\),这样复杂度是 \(\mathcal O(k3^k+m)\)。

事实上只需要枚举每一条非树边的左端点是否选:左端点选时,右端点只能不选;左端点不选时,右端点可选可不选(交给虚树上的 dp 统计)。这样不重不漏地涵盖了全部三种合法情况,复杂度降为 \(\mathcal O(k2^k+m)\)。


code

/*program by mangoyang*/
#include <bits/stdc++.h>
// Leftovers from the author's contest template; inf, Max and Min are not
// referenced by the code below.
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
// Reads one (optionally negative) integer from stdin into x via getchar().
// Non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. If EOF is hit before any digit, x is
// left as 0 and the function returns instead of spinning forever
// (getchar() keeps returning EOF, which is never a digit).
template <class T>
inline void read(T &x){
	int ch = 0, f = 0; x = 0;
	for(; !isdigit(ch); ch = getchar()){
		if(ch == EOF) return;
		if(ch == '-') f = 1;
	}
	for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
	if(f) x = -x;
}
// From here on every `int` is a 64-bit long long, so the product of two
// values below mod fits before it is reduced.
#define int ll
// NOTE: PR::f stores the Euler tour (2n - 1 entries), so N must exceed 2n - 1.
const int N = 300005, mod = 998244353;
// ed  : extra (non-tree) edges; stored as vertex pairs while reading, then
//       remapped in main to positions of the endpoints inside vec.
// g   : spanning-tree adjacency lists; e : virtual-tree child lists;
// vec : sorted, de-duplicated endpoints of the extra edges.
vector<pair<int, int> > ed;
vector<int> g[N], e[N], vec;
// dfn / dep : DFS entry time and depth on the spanning tree; fa : union-find parent;
// s[u][0/1] : product of u's child subtrees without virtual nodes, u unchosen/chosen;
// dp[u][0/1]: independent sets of u's subtree with u unchosen/chosen;
// dp2       : the same on the virtual tree for the mask currently enumerated;
// ali[u][0/1]: 1 if u may be unchosen/chosen under the current mask, else 0.
int dfn[N], dep[N], fa[N], s[N][2], dp[N][2], dp2[N][2], ali[N][2];
// tr[x][i][j]: coefficient taking virtual node x in state i to the top of its
//              compressed chain in state j;
// a : key vertices fed to buildtree; b : all virtual-tree nodes; in[u] : u is virtual;
// pa : virtual-tree parent; st : buildtree's stack; n, m : graph size; ans : result.
int tr[N][2][2], a[N], b[N], in[N], pa[N], st[N], n, m, ans;
// Union-find "find": returns the representative of x's set. Every node on the
// walked path is re-pointed straight at that representative (path
// compression), leaving fa[] exactly as the recursive formulation would.
inline int ask(int x){
	int root = x;
	while(fa[root] != root) root = fa[root];
	while(fa[x] != root){
		int up = fa[x];
		fa[x] = root;
		x = up;
	}
	return root;
}
namespace PR{
	int Log[N], f[N][22], tot;
	inline int chkmin(int x, int y){ return dep[x] < dep[y] ? x : y; }
	inline void dfs(int u, int fa){
		dep[u] = dep[fa] + 1, dfn[u] = ++tot, f[tot][0] = u;
		dp[u][0] = dp[u][1] = 1;
		for(int i = 0; i < (int) g[u].size(); i++){
			int v = g[u][i];
			if(v == fa) continue; 
			dfs(v, u), f[++tot][0] = u;
			(dp[u][1] *= dp[v][0]) %= mod;
			(dp[u][0] *= (dp[v][1] + dp[v][0]) % mod) %= mod;
		}
	}
	inline void solve(){
		dfs(1, 0);
		for(int i = 2; i <= tot; i++) Log[i] = Log[i>>1] + 1;
		for(int j = 1; j <= 21; j++)
			for(int i = 1; i + (1 << j) - 1 <= tot; i++)
				f[i][j] = chkmin(f[i][j-1], f[i+(1<<(j-1))][j-1]);
	}
	inline int Lca(int u, int v){
		int x = dfn[u], y = dfn[v];
		if(x > y) swap(x, y); int g = Log[y-x+1];
		return chkmin(f[x][g], f[y-(1<<g)+1][g]);
	}
}
inline bool cmp(int x, int y){ return dfn[x] < dfn[y]; }
// Builds the virtual tree of the key vertices a[1..len] on the spanning tree
// (a[] is sorted in place by dfn). Afterwards b[1..tot] lists every virtual
// node (key vertices plus the inserted LCAs), pa[] is each node's virtual
// parent, e[] holds the child lists, in[] marks virtual nodes and ali[][] is
// reset to "both states allowed".
// NOTE: relies on a[] containing vertex 1 and no duplicates (main ensures
// both), so the first vertex taken, st[1], is the root.
inline void buildtree(int a[], int len){
	sort(a + 1, a + len + 1, cmp); int top = 0, tot = 0;
	for(int i = 1; i <= len; i++){
		int u = a[i];
		// The first vertex (smallest dfn, i.e. the root) just starts the stack.
		if(!top){ st[++top] = b[++tot] = u; continue; }
		int ca = PR::Lca(u, st[top]);
		// Pop the part of the current chain that lies below ca. Each stacked
		// node's pa is already its stack predecessor; when ca sits strictly
		// between a popped node and that predecessor, re-parent it to ca.
		for(; top > 1 && dep[st[top]] > dep[ca]; top--)
			if(dep[st[top-1]] < dep[ca]) pa[st[top]] = ca;
		// ca is not on the stack yet: it becomes a new virtual node.
		if(st[top] != ca) 
			pa[ca] = st[top], st[++top] = b[++tot] = ca;
		pa[u] = ca, st[++top] = b[++tot] = u;
	}
	// pa of the root is never assigned (stays 0), so the root lands in e[0],
	// which dfs2 (started at vertex 1) never reads.
	for(int i = 1; i <= tot; i++){
		in[b[i]] = 1, e[pa[b[i]]].push_back(b[i]);
		ali[b[i]][0] = ali[b[i]][1] = 1;
	}
}
inline int dfs(int u, int fa){
	int x = 0, k0 = 1, k1 = 1;
	for(int i = 0; i < (int) g[u].size(); i++){
		int v = g[u][i];
		if(v == fa) continue;
		int tmp = dfs(v, u);
		if(tmp) x = tmp;
		else (k0 *= dp[v][0]) %= mod, (k1 *= (dp[v][0] + dp[v][1]) % mod) %= mod;
	}
	s[u][0] = k1, s[u][1] = k0;
	if(in[u]) return tr[u][0][0] = 1, tr[u][1][1] = 1, u;
	if(!x) return 0;
	int tmp[2][2];
	for(int i = 0; i < 2; i++)
		for(int j = 0; j < 2; j++) tmp[i][j] = tr[x][i][j];
	tr[x][0][0] = (tmp[0][0] + tmp[0][1]) % mod * k1 % mod;
	tr[x][0][1] = tmp[0][0] * k0 % mod;
	tr[x][1][0] = (tmp[1][0] + tmp[1][1]) % mod * k1 % mod;
	tr[x][1][1] = tmp[1][0] * k0 % mod;
	return x;
}
// Virtual-tree DP for the current mask: dp2[u][0] / dp2[u][1] count the
// configurations of everything hanging below virtual node u with u unchosen /
// chosen, honouring the per-mask restrictions stored in ali[][].
inline void dfs2(int u){
	dp2[u][0] = ali[u][0] * s[u][0];
	dp2[u][1] = ali[u][1] * s[u][1];
	for(int v : e[u]){
		dfs2(v);
		// Push v's two states through its chain transfer to get the chain
		// top's states: chain[0] = top unchosen, chain[1] = top chosen.
		int chain[2] = {0, 0};
		for(int j = 0; j < 2; j++)
			for(int i = 0; i < 2; i++)
				chain[j] = (chain[j] + dp2[v][i] * tr[v][i][j] % mod) % mod;
		dp2[u][0] = dp2[u][0] * ((chain[0] + chain[1]) % mod) % mod;
		dp2[u][1] = dp2[u][1] * chain[0] % mod;
	}
}
inline void solve(int mask){
	for(int i = 0; i < (int) vec.size(); i++)
		ali[vec[i]][0] = ali[vec[i]][1] = 1;
	for(int i = 0; i < (int) ed.size(); i++){
		int x = ed[i].first, y = ed[i].second;
		int tmp = (1 << i) & mask;
		if(tmp) ali[vec[x]][0] = ali[vec[y]][1] = 0; else ali[vec[x]][1] = 0;
	}
	dfs2(1), (ans += dp2[1][0] + dp2[1][1]) %= mod;
}
signed main(){
	int len = 0;
	read(n), read(m);
	for(int i = 1; i <= n; i++) fa[i] = i;
	for(int i = 1, x, y; i <= m; i++){
		read(x), read(y);
		if(ask(x) == ask(y)){	
			ed.push_back(make_pair(x, y));
			vec.push_back(x), vec.push_back(y);
		}
		else{
			fa[ask(x)] = ask(y);
			g[x].push_back(y), g[y].push_back(x);
			
		}
	}
	PR::solve();
	if(m == n - 1) return cout << (dp[1][0] + dp[1][1]) % mod, 0;
	sort(vec.begin(), vec.end());
	vector<int>::iterator newend = unique(vec.begin(), vec.end());
	vec.erase(newend, vec.end());
	for(int i = 0; i < (int) vec.size(); i++) a[++len] = vec[i];
	if(!len || vec[0] != 1) a[++len] = 1;
	buildtree(a, len), dfs(1, 0);
	for(int i = 0; i < (int) ed.size(); i++){
		int x = lower_bound(vec.begin(), vec.end(), ed[i].first) - vec.begin();
		int y = lower_bound(vec.begin(), vec.end(), ed[i].second) - vec.begin();
		ed[i] = make_pair(x, y);
	}
	for(int i = 0; i < (1 << (m - n + 1)); i++) solve(i);
	cout << ans << endl;
	return 0; 
}
posted @ 2019-03-17 19:51  Joyemang33  阅读(208)  评论(0)  编辑  收藏  举报