DP 题合集

背包 DP

[NOIP 2018 提高组] 货币系统

注意到,\(b\) 必然是 \(a\) 的子集。\(dp_j\) 表示凑出 \(j\) 的方案数。

code
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=110,M=25010;
int t,n,a[N],dp[M],ans;
int main(){
	scanf("%d",&t);
	while(t--){
		ans=0;
		memset(dp,~0x3f,sizeof(dp));
		scanf("%d",&n);
		for(int i=1;i<=n;i++)
			scanf("%d",&a[i]);
		dp[0]=0;
		for(int i=1;i<=n;i++)
			for(int j=a[i];j<=25000;j++)
				dp[j]=max(dp[j],dp[j-a[i]]+1);
		for(int i=1;i<=n;i++)
			if(dp[a[i]]==1)//唯一一种表示方法
				ans++;
		printf("%d\n",ans);
	}
	return 0;
}

Arpa's weak amphitheater and Mehrdad's valuable Hoses

并查集+分组背包。
对于每个连通块分为一组,每组中有组内的所有人单独一个和组内所有人之和。

code
#include<iostream>
#include<vector>
using namespace std;
const int N=1010,M=1e5+10;
int n,m,W,fa[N],w[M],b[M],sumw[N],sumb[N],dp[N],cnt;
int find(int x){
	if(fa[x]==x)
		return x;
	return fa[x]=find(fa[x]);
}
bool vis[N];
vector<int> p[N];
int main(){
	ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
	cin>>n>>m>>W;
	cnt=n;
	for(int i=1;i<=n;i++)
		fa[i]=i;
	for(int i=1;i<=n;i++)
		cin>>w[i];
	for(int i=1;i<=n;i++)
		cin>>b[i];
	for(int i=1;i<=m;i++){
		int x,y;
		cin>>x>>y;
		int a=find(x),b=find(y);
		fa[a]=b;
	}
	for(int i=1;i<=n;i++){
		int x=find(i);
		p[x].push_back(i);
		sumw[x]+=w[i];
		sumb[x]+=b[i];
	}
	for(int i=1;i<=n;i++){
		if(p[i].size()>1){
			p[i].push_back(++cnt);
			w[cnt]=sumw[i];
			b[cnt]=sumb[i];
		}
	}
	for(int i=1;i<=n;i++){
		if(p[i].empty())
			continue;
		for(int j=W;j>=0;j--)
			for(int k:p[i])
				if(j>=w[k])
					dp[j]=max(dp[j],dp[j-w[k]]+b[k]);
	}
	cout<<dp[W];
	return 0;
}

Tak and Cards

\(dp_{k,j}\) 表示选择了 \(k\) 个物品,总和为 \(j\) 时的情况数,统计合法答案即可。

code
#include<cstdio>
template<typename T>
void read(T &x){
	bool f=0;
	x=0;
	char ch=getchar();
	while(ch<'0'||ch>'9'){
		if(ch=='-')
			f=1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	if(f)
		x=~x+1;
	return;
}
template<typename T1,typename...T2>
void read(T1 &x,T2 &...args){
	read(x);
	read(args...);
	return;
}
typedef long long ll;
const int N=55;
int n,a,x[N],sum;
ll dp[N][N*N],ans;
int main(){
	read(n,a);
	for(int i=1;i<=n;i++)
		read(x[i]),sum+=x[i];
	dp[0][0]=1;
	for(int i=1;i<=n;i++)
		for(int j=sum;j>=x[i];j--)
			for(int k=n;k>=1;k--)
				dp[k][j]+=dp[k-1][j-x[i]];
	for(int i=1;i<=n;i++)
		ans+=dp[i][a*i];
	printf("%lld\n",ans);
	return 0;
}

Max Sum Counting

\(a\) 升序排序,\(dp_j\) 表示 \(b\) 的和为 \(j\) 的情况数。

code
#include<cstdio>
#include<algorithm>
#define mod 998244353
template<typename T>
void read(T &x){
	bool f=0;
	x=0;
	char ch=getchar();
	while(ch<'0'||ch>'9'){
		if(ch=='-')
			f=1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	if(f)
		x=~x+1;
	return;
}
template<typename T1,typename...T2>
void read(T1 &x,T2 &...args){
	read(x);
	read(args...);
	return;
}
const int N=5010;
int n,dp[N],ans;
struct node{
	int a,b;
	bool operator<(const node &x)const{
		return a<x.a;
	}
}p[N];
int main(){
	read(n);
	for(int i=1;i<=n;i++)
		read(p[i].a);
	for(int i=1;i<=n;i++)
		read(p[i].b);
	std::sort(p+1,p+1+n);
	dp[0]=1;
	for(int i=1;i<=n;i++)
		for(int j=5000;j>=p[i].b;j--){
			dp[j]=(dp[j]+dp[j-p[i].b])%mod;
			if(p[i].a>=j)
				ans=(ans+dp[j-p[i].b])%mod;
		}
	printf("%d\n",ans);
	return 0;
}

消失之物

\(dp_{j,0}\) 表示容积为 \(j\) 时的方案数。\(dp_{j,1}\) 表示移除某物品后容积为 \(j\) 时的方案数。

code
#include<cstdio>
const int N=2010;
int n,m,w[N],dp[N][2];
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++){
		scanf("%d",w+i);
	}
	dp[0][0]=dp[0][1]=1;
	for(int i=1;i<=n;i++)
		for(int j=m;j>=w[i];j--)
			dp[j][0]=(dp[j][0]+dp[j-w[i]][0])%10;
	for(int i=1;i<=n;i++,printf("\n"))
		for(int j=1;j<=m;j++){
			if(j-w[i]>=0)
				dp[j][1]=(dp[j][0]-dp[j-w[i]][1]+10)%10;//移除物品i
			else
				dp[j][1]=dp[j][0];
			printf("%d",dp[j][1]);
		}
	return 0;
}

区间 DP

Pre-Order

\(dp_{l,r}\) 表示区间 \([l,r]\) 的方案数。枚举断点 \(k\)。若 \(p_l\le p_k\) 并且 \(p_{k+1}\le p_r\),说明合法,则 \(dp_{l,r}\) 就加上 \(dp_{l,k}\times dp_{k+1,r}\)。当前点的 dfn 大于之前的点就说明可以作为之前点的子树中的点。

code
#include<cstdio>
template<typename T>
void read(T &x){
	x=0;
	char ch=getchar();
	while(ch<'0'||ch>'9')
		ch=getchar();
	while(ch>='0'&&ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return;
}
#define mod 998244353
const int N=510;
int n,p[N],dp[N][N];
int main(){
	read(n);
	for(int i=1;i<=n;i++)
		read(p[i]);
	for(int i=1;i<=n;i++)
		dp[i][i]=1;
	for(int len=2;len<=n;len++)
		for(int l=1,r=l+len-1;r<=n;l++,r++)
			for(int k=l;k<r;k++)
				if(p[l]<=p[k]&&p[k+1]<=p[r])
					dp[l][r]=(dp[l][r]+1ll*dp[l][k]*dp[k+1][r])%mod;
	printf("%d\n",dp[1][n]);
	return 0;
}

数据结构优化 DP

Linear Kingdom Races

\(dp_i\) 表示前 \(i\) 条道路的最大利润。
使用线段树维护区间 \([1,j]\) 的最大利润。

  • 不修复道路 \(i\)\(dp_i=dp_{i-1}\)
  • 修复道路 \(i\)\(dp_i=\max_{0 \le j<i}(dp_j+val(j+1,i)-cost(j+1,i))\)
code
#include<cstdio>
#include<vector>
using namespace std;
template<typename T>
void read(T &x){
	x=0;
	char ch=getchar();
	while(ch<'0'||ch>'9')
		ch=getchar();
	while(ch>='0'&&ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return;
}
template<typename T1,typename...T2>
void read(T1 &x,T2 &...args){
	read(x);
	read(args...);
	return;
}
template<typename T>
T Max(const T &a,const T &b){
	return a<b?b:a;
}
typedef long long ll;
#define ls u<<1
#define rs u<<1|1
const int N=2e5+10;
int n,m,w[N];
ll tree[N*4],tag[N*4];
void push_up(int u){
	tree[u]=Max(tree[ls],tree[rs]);
	return;
}
void push_down(int u){
	tree[ls]+=tag[u],tree[rs]+=tag[u];
	tag[ls]+=tag[u],tag[rs]+=tag[u],tag[u]=0;
	return;
}
void modify(int u,int l,int r,int x,int y,ll k){
	if(x<=l&&y>=r){
		tree[u]+=k;
		tag[u]+=k;
		return;
	}
	push_down(u);
	int mid=(l+r)/2;
	if(x<=mid)
		modify(ls,l,mid,x,y,k);
	if(y>mid)
		modify(rs,mid+1,r,x,y,k);
	push_up(u);
	return;
}
ll query(int u,int l,int r,int x,int y){
	if(x<=l&&y>=r)
		return tree[u];
	ll res=0;
	int mid=(l+r)/2;
	push_down(u);
	if(x<=mid)
		res=Max(res,query(ls,l,mid,x,y));
	if(y>mid)
		res=Max(res,query(rs,mid+1,r,x,y));
	return res;
}
vector<pair<int,int>> a[N];
ll dp[N];
int main(){
	read(n,m);
	for(int i=1;i<=n;i++)
		read(w[i]);
	for(int i=1,l,r,p;i<=m;i++){
		read(l,r,p);
		a[r].push_back(make_pair(l,p));
	}
	for(int i=1;i<=n;i++){
		modify(1,0,n,0,i-1,-w[i]);
		for(auto x:a[i])
			modify(1,0,n,0,x.first-1,x.second);
		dp[i]=Max(dp[i-1],query(1,0,n,0,i-1));
		modify(1,0,n,i,i,dp[i]);
	}
	printf("%lld\n",dp[n]);
	return 0;
}

Digital Wallet

code
#include<cstdio>
#include<vector>
using namespace std;
template<typename T>
T Max(const T &a,const T &b){
    return a<b?b:a;
}
typedef long long ll;
const int N=1e5+10;
int n,m,k,a[15][N];
ll dp[N];
int main(){
    scanf("%d%d%d",&n,&m,&k);
    for(int i=1;i<=n;i++)
        for(int j=1;j<=m;j++)
            scanf("%d",&a[i][j]);
    for(int i=1;i<=m;i++)
        for(int j=1;j<=n;j++)
            for(int t=i;t>=Max(i-k+1,1);t--)
                dp[t]=Max(dp[t],dp[t-1]+a[j][i]);
    printf("%lld\n",dp[m-k+1]);
    return 0;
}

状压 DP

[CCO 2015] 路短最

状压 DP 的特点一般就是某个数据范围小于 \(20\)。比如本题的 \(n\),可以直接将点状压表示访问过的城市。转移比较平凡。

code
#include<cstdio>
#include<cstring>
template<typename T>
void read(T &x){
	x=0;
	char ch=getchar();
	while(ch<'0'||ch>'9')
		ch=getchar();
	while(ch>='0'&&ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return;
}
template<typename T,typename...Args>
void read(T &x,Args &...args){
	read(x);
	read(args...);
	return;
}
template<typename T>
T Max(const T &a,const T &b){
	return a<b?b:a;
}
const int N=20;
int n,m,dp[N][1<<18|1],ans,e[N][N];
int main(){
	memset(dp,0xcf,sizeof(dp));
    read(n,m);
	for(int i=1,s,d,l;i<=m;i++){
		read(s,d,l);
		e[s][d]=l;
	}
	dp[0][1]=0;
	for(int i=1;i<(1<<n);i+=2)//小优化,0作为起点必然被访问过
		for(int j=0;j<n;j++)
			if(i>>j&1)
				for(int k=1;k<n;k++){
					if((i>>k&1)&&e[j][k])
						dp[k][i]=Max(dp[k][i],dp[j][1<<k^i]+e[j][k]);
					if(k==n-1)
						ans=Max(ans,dp[k][i]);
				}
	printf("%d\n",ans);
    return 0;
}

[GDOI2014] 拯救莫莉斯

注意到 \(m\le 7\)。维护 \(dp_{i,j,k}\) 表示第 \(i\) 行,第 \(i-1\) 行状态为 \(j\),第 \(i\) 行状态为 \(k\)。可以很容易地同时维护最小代价和最小油库个数。预处理每种状态的代价和油库个数。
转移时,枚举第 \(i-2\),第 \(i-1\) 和第 \(i\) 行状态,但仅需判断第 \(i-1\) 行合法即可转移。在最终统计答案时再判第 \(n\) 行合法。

code
#include<cstdio>
#include<cstring>
template<typename T>
void read(T &x){
    x=0;
	char ch=getchar();
	while(ch<'0'||ch>'9')
		ch=getchar();
	while(ch>='0'&&ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
    return;
}
template<typename T,typename...Args>
void read(T &x,Args &...args){
    read(x);
    read(args...);
    return;
}
int getbit(int x){
	int res=0;
	while(x){
		res++;
		x^=x&-x;
	}
	return res;
}
int n,m,f[55][55],dp1[55][1<<8][1<<8],dp2[55][1<<8][1<<8];
int sum[55][1<<8],ans1,ans2,bit[1<<8];
int main(){
	memset(dp1,0x3f,sizeof(dp1));
	memset(dp2,0x3f,sizeof(dp2));
	ans1=ans2=0x3f3f3f3f;
    read(n,m);
    for(int i=1;i<=n;i++)
		for(int j=0;j<m;j++)
			read(f[i][j]);
	for(int i=0;i<(1<<m);i++)
		bit[i]=getbit(i);
	for(int i=1;i<=n;i++)
		for(int j=0;j<(1<<m);j++)
			for(int k=0;k<m;k++)
				if(j>>k&1)
					sum[i][j]+=f[i][k];
	for(int i=0;i<(1<<m);i++)
		dp1[1][0][i]=sum[1][i],dp2[1][0][i]=bit[i];
	for(int i=2;i<=n;i++)
		for(int j=0;j<(1<<m);j++)
			for(int k=0;k<(1<<m);k++)
				for(int l=0;l<(1<<m);l++){
					if(((j|k|(k<<1)|(k>>1)|l)&((1<<m)-1))!=(1<<m)-1)
						continue;
					int val=dp1[i-1][j][k]+sum[i][l];
					int bitsum=dp2[i-1][j][k]+bit[l];
					if(dp1[i][k][l]>val){
						dp1[i][k][l]=val;
						dp2[i][k][l]=bitsum;
					}
					else if(dp1[i][k][l]==val&&dp2[i][k][l]>bitsum)
						dp2[i][k][l]=bitsum;
				}
	for(int i=0;i<(1<<m);i++)
		for(int j=0;j<(1<<m);j++){
			if(((i|j|(j<<1)|(j>>1))&((1<<m)-1))!=(1<<m)-1)
				continue;
			if(dp1[n][i][j]<ans2){
				ans2=dp1[n][i][j];
				ans1=dp2[n][i][j];
			}
		}
	printf("%d %d\n",ans1,ans2);
    return 0;
}

听课笔记

[春季测试 2023] 圣诞树

因为三角形两边和大于第三边,所以路径不交叉一定比路径交叉优。那么考虑区间 DP。首先断环为链,设 \(dp_{l,r,0/1}\) 表示当前已经走完了区间 \([l,r]\),当前在区间最左/右侧。同时记录 \(pre_{l,r,0/1}\) 表示对应状态的 \(dp\) 是从上一步的左/右侧转移过来的。输出直接递归或用栈就可以了。注意:由于坐标可以为负,所以一定注意边界条件。

code
#include<iostream>
#include<cmath>
#include<algorithm>
using namespace std;
constexpr int N=1e3+10;
int n,s;
struct Node{
    double x,y;
    int id;
}node[N<<1];
double dp[N<<1][N<<1][2],ans;
int pre[N<<1][N<<1][2];
inline double dis(const Node &a,const Node &b){
    double x=a.x-b.x,y=a.y-b.y;
    return sqrt(x*x+y*y);
}
void print(int l,int r,int p){
    if(l==r){
        printf("%d ",node[l].id);
        return;
    }
    if(p){
        print(l,r-1,pre[l][r][p]);
        printf("%d ",node[r].id);
    }
    else{
        print(l+1,r,pre[l][r][p]);
        printf("%d ",node[l].id);
    }
}
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>n;
    node[s].y=-0x7fffffff;
    for(int i=1;i<=n;i++){
        cin>>node[i].x>>node[i].y;
        node[i].id=i;
        node[i+n]=node[i];
        if(node[i].y>node[s].y) s=i;
    }
    for(int i=1;i<=(n<<1);i++)
        for(int j=i;j<=(n<<1);j++)
            dp[i][j][0]=dp[i][j][1]=0x7fffffff;
    dp[s][s][0]=dp[s][s][1]=dp[s+n][s+n][0]=dp[s+n][s+n][1]=0;
    for(int len=2;len<=n;len++)
        for(int l=1,r=l+len-1;r<=(n<<1);l++,r++){
            //0 left ,1 right
            double a=dp[l+1][r][0]+dis(node[l],node[l+1]),b=dp[l+1][r][1]+dis(node[l],node[r]);
            if(a>b) dp[l][r][0]=b,pre[l][r][0]=1;
            else dp[l][r][0]=a,pre[l][r][0]=0;
            a=dp[l][r-1][1]+dis(node[r-1],node[r]),b=dp[l][r-1][0]+dis(node[l],node[r]);
            if(a>b) dp[l][r][1]=b,pre[l][r][1]=0;
            else dp[l][r][1]=a,pre[l][r][1]=1;
        }
    ans=0x7fffffff;
    for(int i=1;i<=n;i++)
        ans=min({ans,dp[i][i+n-1][0],dp[i][i+n-1][1]});
    for(int i=1;i<=n;i++){
        if(dp[i][i+n-1][0]==ans){
            print(i,i+n-1,0);
            return 0;
        }
        if(dp[i][i+n-1][1]==ans){
            print(i,i+n-1,1);
            return 0;
        }
    }
    return 0;
}

[JSOI2018] 潜入行动

树上背包。朴素的状态不足以表达信息时就考虑加状态。\(dp_{u,i,0/1,0/1}\) 表示以 \(u\) 为根的子树中,放了 \(i\) 个监视器,点 \(u\) 放/不放,点 \(u\) 是/否被覆盖。转移时涉及滚动数组,所以我们记 \(temp_{i,0/1,0/1}\) 表示当前的 \(dp_{u,i,0/1,0/1}\)
转移有点麻烦:

\[dp_{u,i+j,0,0}=\sum temp_{i,0,0}\times dp_{v,j,0,1} \]

\[dp_{u,i+j,1,0}=\sum temp_{i,1,0}\times dp_{v,j,0,0/1} \]

\[dp_{u,i+j,0,1}=\sum temp_{i,0,0}\times dp_{v,j,1,1}+temp_{i,0,1}\times dp_{v,j,0/1,1} \]

\[dp_{u,i+j,1,1}=\sum temp_{i,1,0}\times dp_{v,j,1,0/1}+temp_{i,1,1}\times dp_{v,j,0/1,0/1} \]

code
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
template<typename T>
inline void read(T &x){
    bool f=0;x=0;char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-') f=1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        x=(x<<1)+(x<<3)+(ch^48);
        ch=getchar();
    }
    if(f) x=~x+1;
}
template<typename T,typename...Args>
void read(T &x,Args &...args){read(x);read(args...);}
typedef long long ll;
constexpr int mod=1e9+7,N=1e5+10;
struct modint{
    int val;
    modint(int val=0):val(val){}
    modint operator+(const modint &x)const{return modint((val+x.val)%mod);}
    modint &operator+=(const modint &x){val+=x.val;val%=mod;return *this;}
    modint operator*(const modint &x)const{return modint((ll)val*x.val%mod);}
};
vector<int> e[N];
int n,k,siz[N];
modint dp[N][110][2][2],temp[110][2][2];
void dfs(int u,int fa){
    dp[u][1][1][0]=dp[u][0][0][0]=1;
    siz[u]=1;
    for(int v:e[u]){
        if(v==fa) continue;
        dfs(v,u);
        for(int i=0;i<=k;i++){
            temp[i][0][0]=dp[u][i][0][0];
            temp[i][0][1]=dp[u][i][0][1];
            temp[i][1][0]=dp[u][i][1][0];
            temp[i][1][1]=dp[u][i][1][1];
            dp[u][i][0][0]=dp[u][i][0][1]=dp[u][i][1][0]=dp[u][i][1][1]=0;
        }
        for(int i=0;i<=min(siz[u],k);i++)
            for(int j=0;j<=min(siz[v],k-i);j++){
                dp[u][i+j][0][0]+=temp[i][0][0]*dp[v][j][0][1];
                dp[u][i+j][1][0]+=temp[i][1][0]*(dp[v][j][0][0]+dp[v][j][0][1]);
                dp[u][i+j][0][1]+=temp[i][0][0]*dp[v][j][1][1]+
                temp[i][0][1]*(dp[v][j][0][1]+dp[v][j][1][1]);
                dp[u][i+j][1][1]+=temp[i][1][0]*(dp[v][j][1][0]+dp[v][j][1][1])+
                temp[i][1][1]*(dp[v][j][0][0]+dp[v][j][0][1]+dp[v][j][1][0]+dp[v][j][1][1]);
            }
        siz[u]+=siz[v];
    }
}
int main(){
    read(n,k);
    for(int i=1,u,v;i<n;i++){
        read(u,v);
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dfs(1,0);
    printf("%d\n",(dp[1][k][0][1]+dp[1][k][1][1]).val);
    return 0;
}

AT_dp_e Knapsack 2

背包变形。将 \(v\)\(w\) 互换一下就行。

code
#include<cstdio>
#include<cstring>
template<typename T>
inline void read(T &x){
    bool f=0;x=0;char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-') f=1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        x=(x<<1)+(x<<3)+(ch^48);
        ch=getchar();
    }
    if(f) x=~x+1;
}
template<typename T,typename...Args>
void read(T &x,Args &...args){read(x);read(args...);}
template<typename T>
inline T Min(const T &a,const T &b){return a<b?a:b;}
constexpr int N=1e5+10;
typedef long long ll;
int n,W,w,v,sum;
ll dp[N];
int main(){
    read(n,W);
    memset(dp,0x3f,sizeof(dp));
    dp[0]=0;
    for(int i=1;i<=n;i++){
        read(w,v);
        sum+=v;
        for(int j=sum;j>=v;j--)
            dp[j]=Min(dp[j],dp[j-v]+w);
    }
    for(int i=sum;i;i--)
        if(dp[i]<=W){
            printf("%d\n",i);
            return 0;
        }
    return 0;
}

AT_dp_j Sushi

\(dp_{i,j,k}\) 表示有 \(i\) 个装 \(1\) 个寿司,\(j\) 个装 \(2\) 个寿司和 \(k\) 个装 \(3\) 个寿司的盘子的期望次数。\(dp_{i,j,k}\) 为以下四项之和:

  • \(\dfrac{n-(i+j+k)}{n}\times (dp_{i,j,k}+1)\)(空盘子)
  • \(\dfrac{i}{n}\times (dp_{i-1,j,k}+1)\)(放 \(1\) 个)
  • \(\dfrac{j}{n}\times (dp_{i+1,j-1,k}+1)\)(放 \(2\) 个)
  • \(\dfrac{k}{n}\times (dp_{i,j+1,k-1}+1)\)(放 \(3\) 个)

整理,得

\[dp_{i,j,k}=\frac{n}{i+j+k}+\frac{i\times dp_{i-1,j,k}}{i+j+k}+\frac{j\times dp_{i+1,j-1,k}}{i+j+k}+\frac{k\times dp_{i,j+1,k-1}}{i+j+k} \]

显然,应当按 \(k-j-i\) 顺序枚举以消除后效性。边界:\(dp_{0,0,0}=0\)

code
#include<cstdio>
template<typename T>
inline void read(T &x){
    bool f=0;x=0;char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-') f=1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        x=(x<<1)+(x<<3)+(ch^48);
        ch=getchar();
    }
    if(f) x=~x+1;
}
template<typename T,typename...Args>
void read(T &x,Args &...args){read(x);read(args...);}
constexpr int N=310;
int n,cnt[4];
double dp[N][N][N];
int main(){
    read(n);
    for(int i=1,a;i<=n;i++) read(a),++cnt[a];
    for(int k=0;k<=n;k++)
        for(int j=0;j<=n;j++)
            for(int i=0;i<=n;i++){
                if(i==j&&j==k&&k==0) continue;
                double inv=i+j+k;
                dp[i][j][k]=n/inv;
                if(i) dp[i][j][k]+=dp[i-1][j][k]*i/inv;
                if(j) dp[i][j][k]+=dp[i+1][j-1][k]*j/inv;
                if(k) dp[i][j][k]+=dp[i][j+1][k-1]*k/inv;
            }
    printf("%.10lf",dp[cnt[1]][cnt[2]][cnt[3]]);
    return 0;
}
posted @ 2025-04-28 17:32  headless_piston  阅读(29)  评论(0)    收藏  举报