【BZOJ3640】JC的小苹果(高斯消元)

点此看题面

  • 一张\(n\)个点\(m\)条边的无向图,要从\(1\)号点走到\(n\)号点,初始体力为\(hp\)
  • 每当你走到编号为\(i\)的点时,体力都会失去\(a_i\),然后等概率选择当前点的一条边走出去。
  • 当体力值小于等于\(0\)的时候就失败了,求走到\(n\)号点的概率。
  • \(n\le150,m\le5\times10^3,hp\le10^4,0\le a_i\le hp\)

成环的概率\(DP\)

\(f_{k,i}\)表示在体力值为\(k\)时走到\(i\)的概率,显然有转移方程:

\[f_{k,i}=\sum_{j=1}^{n-1}f_{k+a_i,j}\times \frac{w_{i,j}}{deg_j} \]

其中\(w_{i,j}\)表示\(i,j\)之间的边数,注意此题有重边有自环。

看起来这样就完事了,但\(a_i\)可能等于\(0\),也就是说我们不能简单地分层\(DP\),因为在\(k\)相同的状态之间也可能存在转移。

这种时候就要套路地想到高斯消元,把这个转移式看作一个方程。

然而,如果直接暴力这么去做显然会\(T\)飞,因此要考虑优化。

高斯消元的优化

考虑我们总共要做\(hp\)次高斯消元,但实际上每次高斯消元的系数都是相同的,区别只在于等号右边的值。

因此我们可以在一开始先做一遍高斯消元,预处理出\(p_{i,j}\)表示第\(j\)个式子等号右边的值对第\(i\)个式子等号右边的值的贡献系数。

那么接下来每次求解就变成\(O(n^2)\)了。

代码:\(O(n^3+n^2hp)\)

#include<bits/stdc++.h>
#define Tp template<typename Ty>
#define Ts template<typename Ty,typename... Ar>
#define Reg register
#define RI Reg int
#define Con const
#define CI Con int&
#define I inline
#define W while
#define N 150
#define M 5000
#define HP 10000
#define DB double
#define eps 1e-12
#define add(x,y) (e[++ee].nxt=lnk[x],e[lnk[x]=ee].to=y)
using namespace std;
int n,m,hp,a[N+5],d[N+5],w[N+5][N+5];DB f[HP+5][N+5];
namespace Gauss//高斯消元
{
	DB a[N+5][N+5],p[N+5][N+5],v[N+5],res[N+5];
	I void Add(CI i,CI j)//用第i行去消第j行
	{
		DB t=-a[j][i]/a[i][i];for(RI k=1;k<=n;++k) a[j][k]+=t*a[i][k],p[j][k]+=t*p[i][k];
	}
	I void Init()//初始化
	{
		RI i,j,k;DB t;for(i=1;i<=n;++i) for(p[i][i]=1,j=i+1;j<=n;++j) Add(i,j);//从上往下消成三角形
		for(i=n;i;--i) {for(j=1;j<=n;++j) p[i][j]/=a[i][i];for(a[i][i]=1,j=i-1;j;--j) Add(i,j);}//从下往上消得只剩对角线
	}
	I void Solve()//快速求解
	{
		RI i,j;for(i=1;i<=n;++i) for(j=1;j<=n;++j) res[i]+=p[i][j]*v[j];//根据预处理出的系数计算
	}
}
int main()
{
	RI i,j,k,x,y;for(scanf("%d%d%d",&n,&m,&hp),i=1;i<=n;++i) scanf("%d",a+i);
	for(i=1;i<=m;++i) scanf("%d%d",&x,&y),++w[x][y],++d[x],x^y&&(++w[y][x],++d[y]);//注意自环
	for(i=1;i<=n;++Gauss::a[i][i],++i) if(!a[i]) for(j=1;j^n;++j) Gauss::a[i][j]=-1.0*w[i][j]/d[j];//求出系数矩阵
	DB t=0;for(Gauss::Init(),Gauss::v[1]=1,k=hp;k;--k)
	{
		for(i=1;i<=n;++i) if(a[i]&&k+a[i]<=hp) for(j=1;j^n;++j) Gauss::v[i]+=f[k+a[i]][j]*w[i][j]/d[j];//求出等号右边的值
		for(Gauss::Solve(),i=1;i<=n;++i) f[k][i]=Gauss::res[i],Gauss::v[i]=Gauss::res[i]=0;t+=f[k][n];//把值移到DP数组里
	}return printf("%.8lf\n",t),0;//输出答案
}
posted @ 2021-01-12 21:03  TheLostWeak  阅读(23)  评论(0编辑  收藏