【BZOJ 1016】【JSOI 2008】最小生成树计数

http://www.lydsy.com/JudgeOnline/problem.php?id=1016
统计每一个边权在最小生成树中使用的次数,这个次数在任何一个最小生成树中都是固定的(归纳证明)。
在同一个边权上对所有边权为这个的边暴力统计(可以用矩阵树定理),然后用并查集把这个边权的所有边贡献的连通性都加上,再统计下一个边权。
最后把答案乘起来。

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = 103;
const int M = 1003;
const int p = 31011;

struct Edge {
    int u, v, e;
    bool operator < (const Edge &A) const {
        return e < A.e;
    }
} E[M];
int fa[N], n, m, sz[N], val[N], tot[N], l[N], r[N];

int find(int x) {return fa[x] == x ? x : find(fa[x]);}

void merge(int x, int y) {
    fa[x] = y; sz[y] += sz[x];
    while (fa[y] != y) {
        y = fa[y];
        sz[y] += sz[x];
    }
}

void cut(int x, int y) {
    if (fa[x] == y) {
        fa[x] = x;
        sz[y] -= sz[x];
        while (fa[y] != y) {
            y = fa[y];
            sz[y] -= sz[x];
        }
    } else {
        fa[y] = y;
        sz[x] -= sz[y];
        while (fa[x] != x) {
            x = fa[x];
            sz[x] -= sz[y];
        }
    }
}

int dfsl, dfsr, dfstot, sum;

void dfs(int tmp, int nowtot) {
    if (nowtot == dfstot) {++sum; if (sum == p) sum = 0; return;}
    if (tmp > dfsr || dfstot - nowtot > dfsr - tmp + 1) return;
    dfs(tmp + 1, nowtot);
    int u = find(E[tmp].u), v = find(E[tmp].v);
    if (u != v) {
        if (sz[u] < sz[v]) merge(u, v); else merge(v, u);
        dfs(tmp + 1, nowtot + 1);
        cut(u, v);
    }
}

int in() {
    int k = 0; char c = getchar();
    for (; c < '0' || c > '9'; c = getchar());
    for (; c >= '0' && c <= '9'; c = getchar())
        k = k * 10 + c - 48;
    return k;
}

int main() {
    n = in(); m = in();
    int i;
    for (i = 1; i <= m; ++i) {E[i].u = in(); E[i].v = in(); E[i].e = in();}
    stable_sort(E + 1, E + m + 1);
    
    int x, y, num = 0, cnt = 0; val[0] = -1;
    for (i = 1; i <= n; ++i) fa[i] = i, sz[i] = 1;
    for (i = 1; i <= m; ++i) {
        x = find(E[i].u); y = find(E[i].v);
        if (E[i].e != val[num]) {
            r[num] = i - 1;
            val[++num] = E[i].e;
            l[num] = i;
        }
        if (x != y) {
            ++tot[num];
            if (sz[x] < sz[y]) merge(x, y); else merge(y, x);
            ++cnt;
            if (cnt == n - 1)
                break;
        }
    }
    if (cnt < n - 1) {puts("0"); return 0;}
    for (; i <= m && E[i].e == val[num]; ++i);
    r[num] = i - 1;
    
    for (i = 1; i <= n; ++i) fa[i] = i, sz[i] = 1;
    ll ans = 1;
    for (i = 1; i <= num; ++i) {
        sum = 0; dfsl = l[i]; dfsr = r[i]; dfstot = tot[i];
        dfs(dfsl, 0);
        for (int j = dfsl; j <= dfsr; ++j) {
            x = find(E[j].u); y = find(E[j].v);
            if (x != y) if (sz[x] < sz[y]) merge(x, y); else merge(y, x);
        }
        ans = ans * sum % p;
    }
    printf("%lld\n", ans);
    return 0;
}
posted @ 2016-12-20 08:24  abclzr  阅读(...)  评论(... 编辑 收藏