[LOJ2538] [PKUWC2018] Slay the Spire

题目链接

LOJ:https://loj.ac/problem/2538

Solution

计数好题。

首先可以发现这题和期望没关系。

其次对于手上的一套牌,设我们有\(a\)张强化牌,那么:

  • 如果\(a\geqslant k-1\),那么我们显然是从大到小打出\(k-1\)张强化牌,最后打出一张最大的攻击牌。
  • \(\rm otherwise\),我们打出所有的强化牌,再从大到小打出攻击牌。

那么就可以\(dp\)了。

对于强化牌,我们从大到小排序,设\(f[i][j]\)表示当前考虑了前\(i\)种牌,打出了\(j\)种,所有方案的倍率之和。

那么可以得到转移:

  • \(j\leqslant k-1\),我们显然打出这张牌是最优的,\(f[i][j]=f[i-1][j]+f[i-1][j-1]\cdot w[i]\)
  • \(\rm otherwise\),选或不选这张牌我们都不打出,\(f[i][j]=f[i-1][j]+f[i][j]\)

对于攻击牌,我们从小到大排序,设\(g[i][j]\)表示当前考虑了前\(i\)种牌,打出了\(j\)种,所有方案的伤害之和。

  • \(j\leqslant m-(k-1)\),此时我们只能打出一张牌,\(g[i][j]=g[i-1][j]+\binom{i-1}{j-1}\cdot w[i]\)
  • \(\rm otherwise\),我们可以打出多张牌,且应该尽量打后面的牌,\(g[i][j]=g[i-1][j-1]+g[i-1][j]+\binom{i-1}{j-1}\cdot w[i]\)

第一位可以逆循环然后去掉。

最后答案就是\(ans=\sum_{i=0}^m f[i]g[m-i]\)

#include<bits/stdc++.h>
using namespace std;

void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}

void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

#define lf double
#define ll long long 

const int maxn = 2e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 998244353;

int add(int x,int y) {return x+y>mod?x+y-mod:x+y;}
int del(int x,int y) {return x-y<0?x-y+mod:x-y;}
int mul(int x,int y) {return 1ll*x*y-1ll*x*y/mod*mod;}

int n,m,k,a[maxn],b[maxn],f[maxn],g[maxn],fac[maxn],ifac[maxn],inv[maxn];

void prepare() {
	inv[0]=inv[1]=fac[0]=ifac[0]=1;
	for(int i=2;i<=3000;i++) inv[i]=mul(mod-mod/i,inv[mod%i]);
	for(int i=1;i<=3000;i++) fac[i]=mul(fac[i-1],i);
	for(int i=1;i<=3000;i++) ifac[i]=mul(ifac[i-1],inv[i]);
}

int c(int x,int y) {return x>=y?mul(fac[x],mul(ifac[y],ifac[x-y])):0;}

void solve() {
	memset(f,0,sizeof f);
	memset(g,0,sizeof g);
	read(n),read(m),read(k);
	for(int i=1;i<=n;i++) read(a[i]);
	for(int i=1;i<=n;i++) read(b[i]);
	sort(a+1,a+n+1,greater<int> ());
	f[0]=1;
	for(int i=1;i<=n;i++)
		for(int j=n;j;j--)
			if(j<=k-1) f[j]=add(f[j],mul(f[j-1],a[i]));
			else f[j]=add(f[j],f[j-1]);
	sort(b+1,b+n+1);
	for(int i=1;i<=n;i++)
		for(int j=n;j;j--)
			if(j<=m-k+1) g[j]=add(g[j],mul(c(i-1,j-1),b[i]));
			else g[j]=add(g[j],add(g[j-1],mul(c(i-1,j-1),b[i])));
	int ans=0;
	for(int i=0;i<=m;i++) ans=add(ans,mul(f[i],g[m-i]));
	write(ans);
}

int main() {
	prepare();
	int t;read(t);while(t--) solve();
	return 0;
}
posted @ 2019-04-12 10:56  Hyscere  阅读(252)  评论(0编辑  收藏  举报