题解:luogu_P13696 「CyOI」出包魔法师
非常好的一道数学题。
原题链接。
题目分析
我们要意识到的一点是,题目中要求的最优策略,实际上是一个固定的报数的序列,即对于每一个数字 \(i\),你要报 \(a_i\) 次这个数字,并要使得能拿走 \(k\) 张卡牌的概率最大。
那么最终的答案就是
\[\prod_{i = 1}^m \binom{l_i}{a_i}
\]
并且满足 \(\sum_{i = 1}^n a_i = k\)
初始时,每个 \(a_i\) 都为 \(0\),我们将 \(a_i\) 加 \(1\) 对答案的贡献相当于乘上了 \(\dfrac{l_i - a_i}{a_i + 1}\)(由 \(\dbinom{l_i}{a_i}\) 变为 \(\dbinom{l_i}{a_i + 1}\))。
这时就已经有一个贪心的思路了,每次选择一个 \(\dfrac{l_i - a_i}{a_i + 1}\) 最大的数字 \(i\),将其 \(a_i\) 加 \(1\)。可以用一个堆来维护。复杂度 \(O(k \log m)\)。
优化
我们发现如果将 \(l_i\) 从小到大排序,越靠后的数选的一定越多(有点废话)。我们二分 \(l_m\) 选了 \(x\) 个,对于 \(a_i\),我们想让它的贡献尽可能的大。那么就是
\[\dfrac{l_i - (a_i - 1)}{a_i} \ge \dfrac{l_m - (x - 1)}{x}
\]
解得 \(a_i \le \dfrac{(l_i + 1)x}{l_m + 1}\)。说明 \(a_i\) 随 \(x\) 增大而增大,通过 \(\sum a_i\) 与 \(k\) 的大小关系调整二分。
最后有可能会没选满 \(k\),这时一定满足 \(m - k - 1 \le \sum a_i \le k\),此时再使用一个堆来贪心地维护就可以了。
加上线性预处理阶乘及其逆元,时间复杂度 \(O(\max(l_i) + m \log \max(l_i))\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define int long long
constexpr int N = 1e6 + 9;
constexpr int V = 1e7 + 9;
constexpr int mod = 998244353;
int fac[V], ifac[V];
int l[N], a[N];
int m, k;
int ans = 1ll;
struct Node{
int l, a, id;
friend bool operator < (Node x, Node y){return (x.l - x.a) * (y.a + 1) < (y.l - y.a) * (x.a + 1);}
};
priority_queue<Node>q;
int fp(int a, int n){
int res = 1ll;
while(n){
if(n & 1) res = res * a % mod;
a = a * a % mod;
n >>= 1ll;
}
return res % mod;
}
int C(int n, int m){return fac[n] * ifac[m] % mod * ifac[n - m] % mod;}
bool check(int x){
int res = 0;
for(int i = 1; i <= m; i++) res += (l[i] + 1) * x / (l[m] + 1);
return res <= k;
}
void init(){
fac[0] = 1ll;
for(int i = 1; i < V; i++) fac[i] = 1ll * fac[i - 1] * i % mod;
ifac[V - 1] = fp(fac[V - 1], mod - 2); ifac[0] = 1ll;
for(int i = V - 2; i >= 1; i--) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % mod;
}
signed main(){
init();
cin >> m >> k;
for(int i = 1; i <= m; i++) cin >> l[i];
sort(l + 1, l + 1 + m);
int L = 1, R = l[m], p;
while(L <= R){
int mid = (L + R) >> 1;
if(check(mid)){
L = mid + 1;
p = mid;
}
else R = mid - 1;
}
for(int i = 1; i <= m; i++){
a[i] = (l[i] + 1) * p / (l[m] + 1);
k -= a[i];
if(l[i] > a[i]) q.push((Node){l[i], a[i], i});
}
while(k--){
int i = q.top().id; q.pop();
a[i]++;
if(l[i] > a[i]) q.push((Node){l[i], a[i], i});
}
for(int i = 1; i <= m; i++){
ans = ans * C(l[i], a[i]) % mod;
}
cout << ans % mod;
return 0;
}