洛谷 P5299 - [PKUWC2018]Slay the Spire(组合数学+dp)

题面传送门

hot tea 啊……这种风格及难度的题放在省选 D2T1 左右还是挺喜闻乐见的罢

首先考虑对于固定的 \(m\) 张牌怎样求出最优的打牌策略,假设我们抽到了 \(p\) 张强化牌,攻击力从大到小分别为 \(x_1,x_2,\cdots,x_p\),以及 \(q\) 张攻击牌,攻击力从大到小分别为 \(y_1,y_2,\cdots,y_q\),显然如果 \(q=0\) 那就没得打了,总攻击力显然为 \(0\),否则你手玩几组数据就能发现我们肯定会尽量打强化牌直到没有强化牌或者只能再打 \(1\) 张牌为止,打完了强化牌再打攻击牌,证明就套路地用下 exchange arguments,假设现在我们选择了 \(s(s<p)\) 张强化牌与 \(t(t\ge 2)\) 张攻击牌,那么造成的伤害的最大值显然为 \(W_1=\prod\limits_{i=1}^sx_i\sum\limits_{j=1}^ty_j\),考虑多打一张攻击牌,那么造成的伤害就变为 \(W_2=\prod\limits_{i=1}^{s+1}x_i\sum\limits_{j=1}^{t-1}y_j\),这里做差不太容易,故考虑做商,\(\dfrac{W_2}{W_1}=x_{s+1}\times\dfrac{\sum\limits_{j=1}^ty_j}{\sum\limits_{j=1}^{t+1}y_j}\),由于 \(t\ge 2\)\(y_j\ge y_{j+1}\),故一定有 \(y_{t+1}\le\sum\limits_{j=1}^ty_j\),故 \(\dfrac{\sum\limits_{j=1}^ty_j}{\sum\limits_{j=1}^{t+1}y_j}\le\dfrac{1}{2}\),而题目规定 \(x_i\ge 2\),故 \(\dfrac{W_2}{W_1}\ge 1\),即 \(W_2\ge W_1\),因此我们的策略是最优的。

接下来将这个结论应用于原题,首先将所有强化牌和攻击牌按从大到小顺序排序。考虑将选择的 \(m\) 张牌中强化牌与攻击牌的数量分为两类:强化牌数量 \(<k-1\)\(\ge k-1\)。我们预处理 \(dp1_{i,j}\) 表示在前 \(i\) 张强化牌中选择 \(j\) 张强化牌,并且强化牌 \(i\) 必须被选择,所有选牌方案的强化牌上值的乘积之和,再预处理 \(dp2_{i,j}\) 表示在前 \(i\) 张攻击牌中选 \(j\) 张,所有选牌方案的值之和的和。对于强化牌数量 \(\le k-1\) 的情况,我们枚举以下三个量:选择的强化牌数量 \(c\in[0,k-2]\)、最后一个(这里及下文中的“最后一个”指下标最大)被选择的强化牌编号 \(i\),以及最后一个被打出去的攻击牌的编号 \(j\),剩下 \(m-k\) 张牌显然只能在剩余 \(n-j\) 张攻击牌中选,产生的贡献为 \(dp1_{i,c}\times dp2_{j,k-c}\times\dbinom{n-j}{m-k}\)\(i\) 的那一维显然可以前缀和优化掉,复杂度 \(n^2\)。对于强化牌数量 \(>k-1\) 的情况,我们枚举最后一个被打出去的强化牌的编号 \(i\),以及唯一一个被出去的攻击牌的编号 \(j\),剩余 \(m-k\) 张牌可以在剩余 \(n-i\) 张强化牌和 \(n-j\) 张攻击牌中选,产生的贡献就是 \(dp1_{i,k-1}\times b_j\times\dbinom{2n-i-j}{m-k}\),随便枚举一下算一算即可。

最后就是怎样预处理 \(dp1_{i,j},dp2_{i,j}\) 的问题了,其实非常容易,显然有 \(dp\) 方程 \(dp1_{i,j}=\sum\limits_{k=0}^{i-1}dp1_{k,j-1}\times a_i\)\(dp2_{i,j}=\sum\limits_{k=0}^{i-1}dp2_{k,j-1}+a_i\times\dbinom{i-1}{j-1}\)\(\sum\) 那一维显然可以前缀和优化掉,复杂度平方,于是这题就做完了。

const int MAXN=3000;
const int MOD=998244353;
int n,m,k,a[MAXN+5],b[MAXN+5],c[MAXN*2+5][MAXN*2+5];
bool cmp(int lhs,int rhs){return lhs>rhs;}
int dpa[MAXN+5][MAXN+5],sdpa[MAXN+5][MAXN+5];
int dpb[MAXN+5][MAXN+5],sdpb[MAXN+5][MAXN+5];
void solve(){
	scanf("%d%d%d",&n,&m,&k);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	for(int i=1;i<=n;i++) scanf("%d",&b[i]);
	sort(a+1,a+n+1,cmp);sort(b+1,b+n+1,cmp);
	dpa[0][0]=sdpa[0][0]=1;
	for(int i=1;i<=n;i++){
		sdpa[i][0]=1;
		for(int j=1;j<=i;j++){
			dpa[i][j]=1ll*a[i]*sdpa[i-1][j-1]%MOD;
			sdpa[i][j]=(sdpa[i-1][j]+dpa[i][j])%MOD;
//			printf("%d %d %d\n",i,j,dpa[i][j]);
		}
	}
	for(int i=1;i<=n;i++) for(int j=1;j<=i;j++){
		dpb[i][j]=(1ll*b[i]*c[i-1][j-1]%MOD+sdpb[i-1][j-1])%MOD;
		sdpb[i][j]=(sdpb[i-1][j]+dpb[i][j])%MOD;
//		printf("%d %d %d\n",i,j,dpb[i][j]);
	} int ans=0;
	for(int i=0;i<k-1;i++) for(int j=1;j<=n;j++)
		ans=(ans+1ll*sdpa[n][i]*dpb[j][k-i]%MOD*c[n-j][m-k])%MOD;
	for(int i=0;i<=n;i++) for(int j=1;j<=n;j++) if(2*n-i-j>=0)
		ans=(ans+1ll*dpa[i][k-1]*b[j]%MOD*c[2*n-i-j][m-k])%MOD;
	printf("%d\n",ans);
}
int main(){
	for(int i=0;i<=MAXN*2;i++){
		c[i][0]=1;
		for(int j=1;j<=i;j++) c[i][j]=(c[i-1][j-1]+c[i-1][j])%MOD;
	}
	int qu;scanf("%d",&qu);while(qu--) solve();
	return 0;
}
posted @ 2021-04-07 17:05  tzc_wk  阅读(63)  评论(0)    收藏  举报