【NOIP 校内模拟】T2 规避(容斥+最短路计数)

可以先不管符合条件的 先统计出所有的可能走法(最短路条数*最短路条数) 然后减去会相遇的

会相遇的分为在点相遇和在边相遇

在点(设为p)相遇:先保证点在最短路上 然后从s到p的最短路等于从t到p的最短路

在边(设为(x,y,z))相遇:同样需要保证边在最短路上(需要判断三次 同样玄妙♂) 以及相遇的地方一定在边上(两条不同的最短路的两倍不超过总长 这个姿♂势可以记住) 挺玄妙的

#include<bits/stdc++.h>
#define mod 1000000007
#define N 100005
#define M 200005
#define INF 0x3f3f3f3f
#define ll long long
using namespace std;
template<class T>
inline void read(T &x)
{
	x=0; int f=1;
	static char ch=getchar();
	while((!isdigit(ch))&&ch!='-')	ch=getchar();
	if(ch=='-')	f=-1,ch=getchar();
	while(isdigit(ch))	x=x*10+ch-'0',ch=getchar();
	x*=f;
}
struct Edge
{
	int from,to,next;
	ll val;
}edge[2*M],res[2*M];
int n,m,first[N],tot,s,t;
inline void addedge(int x,int y,int z)
{
	tot++;
	edge[tot].from=x; edge[tot].to=y; edge[tot].next=first[x]; edge[tot].val=z; first[x]=tot;
}
typedef pair<ll,int> Pair;
ll cnt[N][3],dis[N][3];
int visit[N];
void dijkstra(int S,int f)
{
	memset(visit,0,sizeof(visit));
	priority_queue<Pair,vector<Pair>,greater<Pair> > heap;
	heap.push(make_pair(0,S)); dis[S][f]=0; cnt[S][f]=1;
	while(!heap.empty())
	{
		int now=heap.top().second;
		heap.pop();
		if(visit[now])	continue;
		visit[now]=1;
		for(int u=first[now];u;u=edge[u].next)
		{
			int vis=edge[u].to;
			if(dis[now][f]+edge[u].val<dis[vis][f])
			{
				cnt[vis][f]=cnt[now][f];
				dis[vis][f]=dis[now][f]+edge[u].val;
				heap.push(make_pair(dis[vis][f],vis));
			}
			else if(dis[now][f]+edge[u].val==dis[vis][f])
				cnt[vis][f]=(cnt[vis][f]+cnt[now][f])%mod;
		}
	}
}
int main()
{
//	freopen("evade.in","r",stdin);
	read(n),read(m),read(s),read(t);
	for(int i=1,x,y,z;i<=m;i++)	
	{
		read(x),read(y),read(z);
		addedge(x,y,z); addedge(y,x,z);
	}
	memset(dis,0x3f,sizeof(dis));
	dijkstra(s,1); dijkstra(t,2);
	ll ans=cnt[t][1]*cnt[t][1]%mod,len=dis[t][1];	//总方案数
	for(int i=1;i<=n;i++)	//先判断点
	{
		if(dis[i][1]+dis[i][2]!=len)	continue;	//不在最短路上 
		if(dis[i][1]==dis[i][2])
		{
			ans=((ans-cnt[i][1]*cnt[i][1]%mod*cnt[i][2]*cnt[i][2]%mod)%mod+mod)%mod;
			continue;
		}
		if(dis[i][1]*2>=len)	continue;	//保证在边上 
		for(int u=first[i];u;u=edge[u].next)
		{
			int vis=edge[u].to;
			if(dis[vis][2]*2>=len)	continue;
			if(len!=dis[i][1]+edge[u].val+dis[vis][2])	continue;    //不在最短路上
			ans=((ans-cnt[i][1]*cnt[vis][2]%mod*cnt[i][1]*cnt[vis][2]%mod)%mod+mod)%mod;		
		}
	}
	cout<<(ans%mod+mod)%mod;
	return 0;
}
posted @ 2018-11-05 16:50  Patrickpwq  阅读(132)  评论(0编辑  收藏  举报