【BZOJ 1016】【JSOI 2008】最小生成树计数

http://www.lydsy.com/JudgeOnline/problem.php?id=1016
统计每一个边权在最小生成树中使用的次数,这个次数在任何一个最小生成树中都是固定的(归纳证明)。
在同一个边权上对所有边权为这个的边暴力统计(可以用矩阵树定理),然后用并查集把这个边权的所有边贡献的连通性都加上,再统计下一个边权。
最后把答案乘起来。

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = 103;
const int M = 1003;
const int p = 31011;

struct Edge {
	int u, v, e;
	bool operator < (const Edge &A) const {
		return e < A.e;
	}
} E[M];
int fa[N], n, m, sz[N], val[N], tot[N], l[N], r[N];

int find(int x) {return fa[x] == x ? x : find(fa[x]);}

void merge(int x, int y) {
	fa[x] = y; sz[y] += sz[x];
	while (fa[y] != y) {
		y = fa[y];
		sz[y] += sz[x];
	}
}

void cut(int x, int y) {
	if (fa[x] == y) {
		fa[x] = x;
		sz[y] -= sz[x];
		while (fa[y] != y) {
			y = fa[y];
			sz[y] -= sz[x];
		}
	} else {
		fa[y] = y;
		sz[x] -= sz[y];
		while (fa[x] != x) {
			x = fa[x];
			sz[x] -= sz[y];
		}
	}
}

int dfsl, dfsr, dfstot, sum;

void dfs(int tmp, int nowtot) {
	if (nowtot == dfstot) {++sum; if (sum == p) sum = 0; return;}
	if (tmp > dfsr || dfstot - nowtot > dfsr - tmp + 1) return;
	dfs(tmp + 1, nowtot);
	int u = find(E[tmp].u), v = find(E[tmp].v);
	if (u != v) {
		if (sz[u] < sz[v]) merge(u, v); else merge(v, u);
		dfs(tmp + 1, nowtot + 1);
		cut(u, v);
	}
}

int in() {
	int k = 0; char c = getchar();
	for (; c < '0' || c > '9'; c = getchar());
	for (; c >= '0' && c <= '9'; c = getchar())
		k = k * 10 + c - 48;
	return k;
}

int main() {
	n = in(); m = in();
	int i;
	for (i = 1; i <= m; ++i) {E[i].u = in(); E[i].v = in(); E[i].e = in();}
	stable_sort(E + 1, E + m + 1);
	
	int x, y, num = 0, cnt = 0; val[0] = -1;
	for (i = 1; i <= n; ++i) fa[i] = i, sz[i] = 1;
	for (i = 1; i <= m; ++i) {
		x = find(E[i].u); y = find(E[i].v);
		if (E[i].e != val[num]) {
			r[num] = i - 1;
			val[++num] = E[i].e;
			l[num] = i;
		}
		if (x != y) {
			++tot[num];
			if (sz[x] < sz[y]) merge(x, y); else merge(y, x);
			++cnt;
			if (cnt == n - 1)
				break;
		}
	}
	if (cnt < n - 1) {puts("0"); return 0;}
	for (; i <= m && E[i].e == val[num]; ++i);
	r[num] = i - 1;
	
	for (i = 1; i <= n; ++i) fa[i] = i, sz[i] = 1;
	ll ans = 1;
	for (i = 1; i <= num; ++i) {
		sum = 0; dfsl = l[i]; dfsr = r[i]; dfstot = tot[i];
		dfs(dfsl, 0);
		for (int j = dfsl; j <= dfsr; ++j) {
			x = find(E[j].u); y = find(E[j].v);
			if (x != y) if (sz[x] < sz[y]) merge(x, y); else merge(y, x);
		}
		ans = ans * sum % p;
	}
	printf("%lld\n", ans);
	return 0;
}
posted @ 2016-12-20 08:24  abclzr  阅读(115)  评论(0编辑  收藏