[JSOI2008]最小生成树计数
JSOI2008 最小生成树计数
阴间消消乐(躺
先给代码,之后补题解
#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long ll;
const int maxn=1010;
const ll mod=10337;
const ll mod1=31011;
struct edge{
int from,to,dis;
}g[maxn<<1];
bool cmp(edge a,edge b)
{
return a.dis<b.dis;
}
int inv[mod<<1];
int n,m;
int _max=0;
struct BCJ{
int f[maxn];
void init()
{
for(int i=1;i<maxn;i++)f[i]=i;
}
int gf(int x)
{
if(f[x]!=x)f[x]=gf(f[x]);
return f[x];
}
bool merge(int x,int y)
{
x=gf(x),y=gf(y);
if(x==y)return 0;
f[y]=x;
return 1;
}
}s1,s2;
ll ans=1;
ll a[120][120];
bool mark[maxn];
int id[maxn];
ll HLS(int n)
{
n--;
int i,j,k,l;
int pos=1;
int s=1;
bool fu=0;
for(i=1;i<=n;i++,pos++)
{
int _max=i;
for(j=i+1;j<=n;j++)
{
if(abs(a[j][pos])>abs(a[_max][pos]))_max=j;
}
if(!a[_max][pos])continue;
if(_max!=i)
{
for(j=pos;j<=n+1;j++)swap(a[i][j],a[_max][j]);
fu^=1;
}
for(j=i+1;j<=n;j++)
{
if(a[j][pos]){
int ori=a[j][pos];
s=(s*a[i][pos])%mod;
for(k=pos;k<=n+1;k++)
{
a[j][k]=a[j][k]*a[i][pos]-ori*a[i][k];
}
}
}
}
ll res=1;
if(s<0)res*=-1;
if(fu)res*=-1;
for(i=1;i<=n;i++)res=(res*a[i][i])%mod;
// printf("---%d\n",s);
return res*inv[abs(s)]%mod;
}
bool use[maxn];
ll col(int val)
{
memset(mark,0,sizeof(mark));
memset(a,0,sizeof(a));
int i,j;
s1.init();
vector<edge>q;
for(i=1;g[i].dis<=_max&&i<=m;i++)
{
if(g[i].dis!=val){
if(use[i])
s1.merge(g[i].from,g[i].to);
}
else q.push_back(g[i]);
}
int num=0;
for(i=1;i<=n;i++)
{
if(!mark[s1.gf(i)])id[s1.gf(i)]=++num,mark[s1.gf(i)]=1;
}
for(i=0;i<q.size();i++)
{
int f=q[i].from,t=q[i].to;
f=id[s1.gf(f)],t=id[s1.gf(t)];
// printf("--%d %d\n",f,t);
a[f][f]+=1;
a[t][t]+=1;
a[f][t]-=1;
a[t][f]-=1;
}
// printf("--%d\n",_max);
return HLS(num);
}
int main()
{
int i,j;
inv[1]=1;
for(i=2;i<mod;i++)inv[i]=(mod-mod/i)*inv[mod%i]%mod;
scanf("%d%d",&n,&m);
s1.init();
for(i=1;i<=m;i++)
{
scanf("%d%d%d",&g[i].from,&g[i].to,&g[i].dis);
}
sort(g+1,g+m+1,cmp);
int k=n-1;
i=1;
while(k&&(i<=m)){
if(s1.merge(g[i].from,g[i].to)){
use[i]=1;
_max=g[i].dis;
k--;
}
i++;
}
if(k){
printf("0\n");
return 0;
}
for(i=1;i<=m;i++)
{
if(g[i].dis>_max)break;
if(g[i].dis==g[i-1].dis)continue;
// cout<<g[i].dis<<endl;
ans=(ans*col(g[i].dis))%mod1;
}
printf("%lld\n",ans);
}