把博客园图标替换成自己的图标
把博客园图标替换成自己的图标end

【HDU6327】Random Sequence(记忆化搜索)

点此看题面

大致题意: 给你两个序列\(a,v\),其中\(a\)数组由\(0\sim m\)组成。随机用\(1\sim m\)中的一个数替换\(a\)中的\(0\),求\(\sum_{i=1}^{n-3}v_{gcd(a_i,a_{i+1},a_{i+2},a_{i+3})}\)的期望值。

记忆化搜索

考虑记忆化搜索,设\(f_{i,s_0,s_1,s_2}\)表示当前是第\(i\)位,前\(3,2,1\)个数的\(gcd\)分别是\(s_0,s_1,s_2\)时之后所有情况的元素乘积总和(求期望可以在记忆化搜索完后除以总方案数)。

则显然,设当前选择的数为\(t\),得到的值就是\(v_{gcd(s_0,t)}\cdot f_{i+1,gcd(s_1,t),gcd(s_2,t),t}\)

由于对于\(1\sim100\)以内的四个数\(a,b,c,d\),满足\(a|b,b|c,c|d\)的情况数是非常少的(据说只有\(1500\)个左右),所以是能过的。

注意最好把记忆化数组中使用过的位置存下来,方便清空。

代码

#include<bits/stdc++.h>
#define Tp template<typename Ty>
#define Ts template<typename Ty,typename... Ar>
#define Reg register
#define RI Reg int
#define Con const
#define CI Con int&
#define I inline
#define W while
#define N 100
#define X 1000000007
#define Inc(x,y) ((x+=(y))>=X&&(x-=X))
#define Qinv(x) Qpow(x,X-2)
using namespace std;
int n,m,a[N+5],v[N+5],gcd[N+5][N+5];
I int Qpow(RI x,RI y) {RI t=1;W(y) y&1&&(t=1LL*t*x%X),x=1LL*x*x%X,y>>=1;return t;}
I int Gcd(CI x,CI y) {return y?Gcd(y,x%y):x;}
class MemorizedSearcher//记忆化搜索
{
	private:
		#define pb push_back
		struct data
		{
			int id,s[3];I data(CI p=0,CI x=0,CI y=0,CI z=0):id(p),s({x,y,z}){}
			I data operator + (CI y) Con {return data(id+1,gcd[s[1]][y],gcd[s[2]][y],y);}
		};
		int f[N+5][N+5][N+5][N+5];vector<data> vis;
		I int dfs(Con data& x)//搜索
		{
			#define DFS(t) (1LL*(x.id>3?v[gcd[x.s[0]][t]]:1)*dfs(x+t)%X)//下一个状态
			#define F(x) f[x.id][x.s[0]][x.s[1]][x.s[2]]//当前记忆化数组
			if(x.id>n) return 1;if(F(x)) return F(x);vis.pb(x);//判断边界和已访问,开vector存储记忆化数组中使用过的位置
			if(a[x.id]) return F(x)=DFS(a[x.id]);//如果已给定数字
			for(RI i=1;i<=m;++i) Inc(F(x),DFS(i));return F(x);//枚举数字进行搜索
		}
	public:
		I void Solve()
		{
			RI i,t=0;for(i=1;i<=n;++i) t+=!a[i];printf("%d\n",1LL*dfs(1)*Qinv(Qpow(m,t))%X);//求解并输出答案
			for(t=vis.size(),i=0;i^t;++i) F(vis[i])=0;vis.clear();//清空
		}
}M;
int main()
{
	RI Tt,i,j;for(i=1;i<=N;++i) for(j=1;j<=N;++j) gcd[i][j]=Gcd(i,j);//初始化gcd
	scanf("%d",&Tt);W(Tt--)
	{
		for(scanf("%d%d",&n,&m),i=1;i<=n;++i) scanf("%d",a+i);
		for(i=1;i<=m;++i) scanf("%d",v+i);M.Solve();
	}return 0;
}
posted @ 2019-07-16 19:18  TheLostWeak  阅读(407)  评论(0编辑  收藏  举报