题目链接

https://www.lydsy.com/JudgeOnline/problem.php?id=5467

题解

容易发现,强化牌和攻击牌按大的出,只要能出强化牌就出强化牌是最优策略。

首先将两种牌按权值排序,设fi,jf_{i,j}表示前ii张强化牌取到jj张,必定取到第ii张能打出的最大强化值之和,gi,jg_{i,j}表示前ii张攻击牌取到jj张,必定取到第ii张能打出的最大伤害之和,容易发现
fi,j=wi×k=j1i1fk,j1gi,j=(i1j1)wi+k=j1i1fk,j1 f_{i,j}=w_i\times \sum_{k=j-1}^{i-1}f_{k,j-1}\\ g_{i,j}=\binom{i-1}{j-1}w_i+\sum_{k=j-1}^{i-1}f_{k,j-1}
可以前缀和优化。

F(i,j)F(i,j)表示摸到ii张强化牌打出jj张的最大强化值之和,G(i,j)G(i,j)表示摸到ii张攻击牌打出jj张的最大伤害之和。
F(i,j)=k=infk,j(nkij)G(i,j)=k=ingk,j(nkij) F(i,j)=\sum_{k=i}^n f_{k,j}\binom{n-k}{i-j}\\ G(i,j)=\sum_{k=i}^n g_{k,j}\binom{n-k}{i-j}
最后答案为
i=1n1F(i,i)G(mi,ki)+i=nmF(i,n1)G(mi,1) \sum_{i=1}^{n-1} F(i,i)G(m-i,k-i)+\sum_{i=n}^m F(i,n-1)G(m-i,1)

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
 
int read()
{
  int x=0,f=1;
  char ch=getchar();
  while((ch<'0')||(ch>'9'))
    {
      if(ch=='-')
        {
          f=-f;
        }
      ch=getchar();
    }
  while((ch>='0')&&(ch<='9'))
    {
      x=x*10+ch-'0';
      ch=getchar();
    }
  return x*f;
}
 
const int maxn=3000;
const int mod=998244353;
 
int quickpow(int a,int b,int m)
{
  int res=1;
  while(b)
    {
      if(b&1)
        {
          res=1ll*res*a%m;
        }
      a=1ll*a*a%m;
      b>>=1;
    }
  return res;
}
 
bool cmp(int a,int b)
{
  return a>b;
}
 
int n,f[maxn+10][maxn+10],g[maxn+10][maxn+10],fac[maxn+10],ifac[maxn+10];
 
int C(int a,int b)
{
  if(b>a)
    {
      return 0;
    }
  if(a<0)
    {
      return 0;
    }
  if(b<0)
    {
      return 0;
    }
  return 1ll*fac[a]*ifac[b]%mod*ifac[a-b]%mod;
}
 
int F(int a,int b)
{
  if(a>n)
    {
      return 0;
    }
  if(b>n)
    {
      return 0;
    }
  if(a<b)
    {
      return 0;
    }
  int res=0;
  for(int i=b; i<=n; ++i)
    {
      res=(res+1ll*C(n-i,a-b)*f[b][i])%mod;
    }
  return res;
}
 
int G(int a,int b)
{
  if(a>n)
    {
      return 0;
    }
  if(b>n)
    {
      return 0;
    }
  if(a<b)
    {
      return 0;
    }
  int res=0;
  for(int i=b; i<=n; ++i)
    {
      res=(res+1ll*C(n-i,a-b)*g[b][i])%mod;
    }
  return res;
}
 
int T,m,k,a[maxn+10],b[maxn+10],sf[maxn+10][maxn+10],sg[maxn+10][maxn+10];
 
int main()
{
  fac[0]=1;
  for(int i=1; i<=maxn; ++i)
    {
      fac[i]=1ll*fac[i-1]*i%mod;
    }
  ifac[0]=ifac[1]=1;
  for(int i=2; i<=maxn; ++i)
    {
      ifac[i]=1ll*(mod-mod/i)*ifac[mod%i]%mod;
    }
  for(int i=1; i<=maxn; ++i)
    {
      ifac[i]=1ll*ifac[i]*ifac[i-1]%mod;
    }
  T=read();
  while(T--)
    {
      n=read();
      m=read();
      k=read();
      for(int i=1; i<=n; ++i)
        {
          a[i]=read();
        }
      for(int i=1; i<=n; ++i)
        {
          b[i]=read();
        }
      std::sort(a+1,a+n+1,cmp);
      std::sort(b+1,b+n+1,cmp);
      f[0][0]=1;
      g[0][0]=0;
      for(int i=0; i<=n; ++i)
        {
          sf[0][i]=1;
          sg[0][i]=0;
        }
      for(int i=1; i<=n; ++i)
        {
          for(int j=0; j<i; ++j)
            {
              f[i][j]=g[i][j]=sf[i][j]=sg[i][j]=0;
            }
          for(int j=i; j<=n; ++j)
            {
              f[i][j]=1ll*a[j]*sf[i-1][j-1]%mod;
              g[i][j]=(1ll*C(j-1,i-1)*b[j]+sg[i-1][j-1])%mod;
              sf[i][j]=sf[i][j-1]+f[i][j];
              if(sf[i][j]>=mod)
                {
                  sf[i][j]-=mod;
                }
              sg[i][j]=sg[i][j-1]+g[i][j];
              if(sg[i][j]>=mod)
                {
                  sg[i][j]-=mod;
                }
            }
        }
      int ans=0;
      for(int i=0; i<k; ++i)
        {
          ans=(ans+1ll*F(i,i)*G(m-i,k-i))%mod;
        }
      for(int i=k; i<=m; ++i)
        {
          ans=(ans+1ll*F(i,k-1)*G(m-i,1))%mod;
        }
      printf("%d\n",ans);
    }
  return 0;
}