题解:luogu_P13696 「CyOI」出包魔法师

非常好的一道数学题。

原题链接

题目分析

我们要意识到的一点是,题目中要求的最优策略,实际上是一个固定的报数的序列,即对于每一个数字 \(i\),你要报 \(a_i\) 次这个数字,并要使得能拿走 \(k\) 张卡牌的概率最大。

那么最终的答案就是

\[\prod_{i = 1}^m \binom{l_i}{a_i} \]

并且满足 \(\sum_{i = 1}^n a_i = k\)

初始时,每个 \(a_i\) 都为 \(0\),我们将 \(a_i\)\(1\) 对答案的贡献相当于乘上了 \(\dfrac{l_i - a_i}{a_i + 1}\)(由 \(\dbinom{l_i}{a_i}\) 变为 \(\dbinom{l_i}{a_i + 1}\))。

这时就已经有一个贪心的思路了,每次选择一个 \(\dfrac{l_i - a_i}{a_i + 1}\) 最大的数字 \(i\),将其 \(a_i\)\(1\)。可以用一个堆来维护。复杂度 \(O(k \log m)\)

优化

我们发现如果将 \(l_i\) 从小到大排序,越靠后的数选的一定越多(有点废话)。我们二分 \(l_m\) 选了 \(x\) 个,对于 \(a_i\),我们想让它的贡献尽可能的大。那么就是

\[\dfrac{l_i - (a_i - 1)}{a_i} \ge \dfrac{l_m - (x - 1)}{x} \]

解得 \(a_i \le \dfrac{(l_i + 1)x}{l_m + 1}\)。说明 \(a_i\)\(x\) 增大而增大,通过 \(\sum a_i\)\(k\) 的大小关系调整二分。

最后有可能会没选满 \(k\),这时一定满足 \(m - k - 1 \le \sum a_i \le k\),此时再使用一个堆来贪心地维护就可以了。

加上线性预处理阶乘及其逆元,时间复杂度 \(O(\max(l_i) + m \log \max(l_i))\)

Code

#include <bits/stdc++.h>
using namespace std;

#define int long long

constexpr int N = 1e6 + 9;
constexpr int V = 1e7 + 9;
constexpr int mod = 998244353;
int fac[V], ifac[V];
int l[N], a[N];
int m, k;
int ans = 1ll;

struct Node{
	int l, a, id;
	friend bool operator < (Node x, Node y){return (x.l - x.a) * (y.a + 1) < (y.l - y.a) * (x.a + 1);}
};
priority_queue<Node>q;

int fp(int a, int n){
	int res = 1ll;
	while(n){
		if(n & 1) res = res * a % mod;
		a = a * a % mod;
		n >>= 1ll;
	}
	return res % mod;
}

int C(int n, int m){return fac[n] * ifac[m] % mod * ifac[n - m] % mod;}

bool check(int x){
	int res = 0;
	for(int i = 1; i <= m; i++) res += (l[i] + 1) * x / (l[m] + 1);
	return res <= k;
}

void init(){
	fac[0] = 1ll;
	for(int i = 1; i < V; i++) fac[i] = 1ll * fac[i - 1] * i % mod;
	ifac[V - 1] = fp(fac[V - 1], mod - 2); ifac[0] = 1ll;
	for(int i = V - 2; i >= 1; i--) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % mod;
}

signed main(){
	init();
	cin >> m >> k;
	for(int i = 1; i <= m; i++) cin >> l[i];
	sort(l + 1, l + 1 + m);

	int L = 1, R = l[m], p;
	while(L <= R){
		int mid = (L + R) >> 1;
		if(check(mid)){
			L = mid + 1;
			p = mid;
		}
		else R = mid - 1;
	}
	for(int i = 1; i <= m; i++){
		a[i] = (l[i] + 1) * p / (l[m] + 1);
		k -= a[i];
		if(l[i] > a[i]) q.push((Node){l[i], a[i], i});
	}
	while(k--){
		int i = q.top().id; q.pop();
		a[i]++;
		if(l[i] > a[i]) q.push((Node){l[i], a[i], i});
	}

	for(int i = 1; i <= m; i++){
		ans = ans * C(l[i], a[i]) % mod;
	}
	cout << ans % mod;
	return 0;
}

posted @ 2025-08-19 20:51  dairuize  阅读(15)  评论(0)    收藏  举报