BZOJ 1016 最小生成树计数(矩阵树定理)

我们把边从小到大排序,然后依次插入一种权值的边,然后把每一个联通块合并。
然后当一次插入的边不止一条时做矩阵树定理就行了。算出有多少种生成树就行了。
剩下的交给乘法原理。
实现一不小心就会让程序变得很丑


#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
#define int long long
const int mod=31011;
const int N=110;
int fa[N],a[N][N][N],n,m,b[1010],cnt[N],id[N],w[N],ans[N],mmp[N];
struct edge{
    int u,v,w;
}e[1010];
bool cmp(edge a,edge b){
    return a.w<b.w;
}
int find(int x){
    if(fa[x]==x)return x;
    else return fa[x]=find(fa[x]);
}
int gauss(int x,int n){
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            a[x][i][j]=(a[x][i][j]+mod)%mod;
    int f=1,ans=1;
    for(int i=1;i<=n;i++){
        for(int j=i+1;j<=n;j++){
            int A=a[x][i][i],B=a[x][j][i];
            while(B){
                int t=A/B;A%=B;swap(A,B);
                for(int k=i;k<=n;k++)a[x][i][k]=(a[x][i][k]-t*a[x][j][k]%mod+mod)%mod;
                for(int k=i;k<=n;k++)swap(a[x][i][k],a[x][j][k]);
                f=-f;
            }
        }
        ans=ans*a[x][i][i]%mod;
    }
    memset(a[x],0,sizeof(a[x]));
    return (ans*f+mod)%mod;;
}
int read(){
    int sum=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){sum=sum*10+ch-'0';ch=getchar();}
    return sum*f;
}
void init(){
    for(int j=1;j<=n;j++)cnt[j]=0;
    for(int j=1;j<=n;j++)fa[j]=j;
}
signed main(){
    n=read(),m=read();
    for(int i=1;i<=m;i++)e[i].u=read(),e[i].v=read(),e[i].w=read(),b[i]=e[i].w;
    sort(b+1,b+1+m);
    int num=unique(b+1,b+1+m)-b-1;
    for(int i=1;i<=m;i++)e[i].w=lower_bound(b+1,b+1+num,e[i].w)-b;
    sort(e+1,e+1+m,cmp);
    int now=1;
    for(int i=1;i<=n;i++)ans[i]=1;
    for(int i=1;i<=num;i++){
        init();
        int line=now,tmp=0;
        while(line<=m&&e[line].w==i){
            if(e[line].u==e[line].v){line++;continue;}
            int x=find(e[line].u),y=find(e[line].v);
            if(x!=y)fa[x]=y;
            line++;
        }
        for(int j=1;j<=n;j++)fa[j]=find(j);
        for(int j=1;j<=n;j++)id[j]=++cnt[fa[j]];
        for(int j=now;j<=line-1;j++){
            if(e[j].u==e[j].v)continue;
            a[fa[e[j].u]][id[e[j].u]][id[e[j].u]]++;
            a[fa[e[j].v]][id[e[j].v]][id[e[j].v]]++;
            a[fa[e[j].u]][id[e[j].u]][id[e[j].v]]--;
            a[fa[e[j].u]][id[e[j].v]][id[e[j].u]]--;
        }
        for(int j=1;j<=n;j++)w[j]=1;
        for(int j=1;j<=n;j++)w[fa[j]]=w[fa[j]]*ans[j]%mod;
        for(int j=1;j<=n;j++)
            if(cnt[j]){
                mmp[j]=++tmp;
                if(cnt[j]==1)ans[tmp]=w[j];
                else ans[tmp]=w[j]*gauss(j,cnt[j]-1)%mod;
            }
        for(int j=line;j<=m;j++)e[j].u=mmp[fa[e[j].u]],e[j].v=mmp[fa[e[j].v]];
        now=line;n=tmp;
    }
    if(n>1)printf("0");
    else printf("%lld",ans[1]);
    return 0;
}
posted @ 2019-03-05 21:33  Xu-daxia  阅读(...)  评论(... 编辑 收藏