【bzoj5056】OI游戏 最短路+矩阵树定理

题目描述

给出一张无向图,求满足 0号点到所有点的路径长等于原图中它们之间最短路 的生成树的个数。

输入

第一行一个整数N,代表原图结点。
接下来N行,每行N个字符,描绘了一个邻接矩阵。邻接矩阵中,
如果某一个元素为0,代表这两个点之间不存在边,
并且保证第i行第i列的元素为0,第i行第j列的元素(i≠j)等于第j行第i列的元素。
2≤N≤50

输出

一行一个整数,代表删法总方案数膜1,000,000,007的结果。

样例输入

4
0123
1012
2101
3210

样例输出

6


题解

最短路+矩阵树定理

首先求出这张图以0为起点的最短路径图,即边只能在这些图中选择。这里由于数据范围小,随便哪种最短路都可以。代码中写了Floyd。

然后考虑一个点是从哪个个节点更新的:最短路径图上指向它的所有边都可以选择。(注意:最短路径图是有向的。即如果a能更新b则有a->b,不代表b能更新a。)

所以答案就是最短路径图上以0为根的生成外向树形图的数目。求 入度矩阵-邻接矩阵 删去0所在行列的行列式的值即为答案。

时间复杂度$O(n^3)$

UPD:我SB了。。。能直接求出的干嘛要用矩阵树定理。。。

由于最短路径图是一个DAG,因此相当于除了1号点以外,其它点选择其入边的任意一条均可(类似于归纳法),所以答案就是最短路径图中除了一号点以外其它所有的点的入度之积。

(转化为矩阵树定理,即如果按照拓扑序给点编号的话,相当于得到的矩阵是一个上三角矩阵,直接求对角线(入度)乘积即为答案)。

#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 50
#define mod 1000000007
using namespace std;
typedef long long ll;
int map[N][N] , dis[N][N];
ll a[N][N];
char str[60];
inline ll pow(ll x , int y)
{
	ll ans = 1;
	while(y)
	{
		if(y & 1) ans = ans * x % mod;
		x = x * x % mod , y >>= 1;
	}
	return ans;
}
int main()
{
	int n , i , j , k , d = 0;
	ll t , ans = 1;
	scanf("%d" , &n);
	memset(map , 0x3f , sizeof(map)) , memset(dis , 0x3f , sizeof(dis));
	for(i = 0 ; i < n ; i ++ )
	{
		scanf("%s" , str) , dis[i][i] = 0;
		for(j = 0 ; j < n ; j ++ )
			if(str[j] != '0')
				map[i][j] = dis[i][j] = str[j] - '0';
	}
	for(k = 0 ; k < n ; k ++ )
		for(i = 0 ; i < n ; i ++ )
			for(j = 0 ; j < n ; j ++ )
				dis[i][j] = min(dis[i][j] , dis[i][k] + dis[k][j]);
	for(i = 0 ; i < n ; i ++ )
		for(j = 0 ; j < n ; j ++ )
			if(map[i][j] != 0x3f3f3f3f && dis[0][i] + map[i][j] == dis[0][j])
				a[j][j] ++ , a[i][j] = (a[i][j] - 1 + mod) % mod;
	for(i = 1 ; i < n ; i ++ )
	{
		for(j = i ; j < n ; j ++ )
			if(a[i][j])
				break;
		if(j == n) continue;
		if(j != i)
		{
			d ^= 1;
			for(k = i ; k < n ; k ++ ) swap(a[i][k] , a[j][k]);
		}
		ans = ans * a[i][i] % mod;
		for(t = pow(a[i][i] , mod - 2) , j = i ; j < n ; j ++ )
			a[i][j] = a[i][j] * t % mod;
		for(j = i + 1 ; j < n ; j ++ )
			for(t = a[j][i] , k = i ; k < n ; k ++ )
				a[j][k] = (a[j][k] - a[i][k] * t % mod + mod) % mod;
	}
	for(i = 1 ; i < n ; i ++ ) ans = ans * a[i][i] % mod;
	if(d) ans = (mod - ans) % mod;
	printf("%lld\n" , ans);
	return 0;
}

 

 

posted @ 2017-09-27 20:16  GXZlegend  阅读(361)  评论(0编辑  收藏  举报