洛谷P9838 挑战 NPC IV

0x1f 省流

给出一个 $1 \to n $ 排列 \(a_1 \to a_n\),该排列的权值为 \(S\)

\[S = \sum_{l=1}^n \sum_{r=l}^n \sum_{i=l}^r \big(\log2(lowbit(a_i)) + 1 \big) \]

求第 \(k\) 小的 \(S\)

0x2f 关于思路

\(NPC\) 问题显然可做

\(\textbf{Subtask 1: } n \le 10\)

暴力枚举排列、暴力枚举区间、暴力计算贡献、暴力······

\(O(n^2n!)\)

\(\textbf{Subtask 2: } k \le 2\)

\[f(x) = \log_2(lowbit(a_i)) + 1 \]

这个函数,发现一定存在 \(f(a)=f(b)\),相当的廉价。

什么意思呢?

我们在排列中直接交换 \(a,b\),发现 \(S\) 不变,实际上求最小的 \(S\)

关于 \(k\) 的约束条件消失了。

首先考虑 \(k=1\) 的时候怎么做。

考虑最终的位置 \(i\) 会被多少个区间统计到,不难发现它计入贡献的次数是 \(b_i = i \times (n-i+1)\)
也就是说问题转化成,令 \(a_i = f(i)\)\(b_i = i \times (n-i+1)\)。,我们要对 \(a\) 序列指定一个顺序,最小化 \(\sum_{i=1}^n a_i \times b_i\)

这是我们小学二年级就学过的贪心。

\(O(n)\) 过掉。

\(\textbf{Subtask 3: } 10^5 \le n \le 10^6\)

\(f(a)=f(b)\),没错就是这个廉价的性质。

经过初步计算,\(f(a) = f(b) = 1\) 约有 \(\dfrac{n}{2}\) 个解,\(f(a) = f(b) = 2\) 约有 \(\dfrac{n}{4}\) 个解.

所以总的最值方案数共有 \(\prod_{i=1}^{\log_2 n} {\dfrac{n}{2^i}!}\) 种。

经过实际精确的计算,在 \(n = 29\) 时,总最值方案数超过了 \(10^{18}\)

因此我们可以继续贪心。

\(O(n)\) 过掉。

0x3f 关于正解

事实上,在上一节我们已经积累了很多奇技淫巧来完成这道题。

这里采用数据 点分治 的做法。

\(\textbf{Case 1: } n > 28\)

\(O(n)\) 不足以让我们过掉 \(10^{18}\) 的数据。

显然可以优化 \(O(n)\) 的做法。

考虑 \(\sum_{i=1}^n a_i \times b_i\) 这个式子。

我们之前的做法是排序后一个一个计算。

事实上,\(a_i\) 是重复的,将一段 \(a_i\) 提出来,求

\[a_i \times \sum_{i=l}^r b_i \]

我们展开式子:

\[\begin{align} &\sum_{i=l}^r b_i \\ =& \sum_{i=l}^r i \times (n-i+1) \\ =& \sum_{i=l}^r \big(i \times (n+1) - i^2 \big) \\ =& (n+1) \times \sum_{i=l}^r i - \sum_{i=l}^r i^2 \\ =& (n+1) \times \dfrac{(l+r)\times(r-l+1)}{2} - \dfrac{r\times(r+1)\times(2r+1)}{6} + \dfrac{l\times(l-1)\times(2l-1)}{6} \end{align} \]

可以看到,对于每一个 \(a_i\),可以 \(O(1)\) 求出 \(a_i \times \sum_{i=l}^r b_i\)

因此复杂度 \(O(\log n)\)

\(\textbf{Case 2: } n \le 28\)

只需要 \(\textbf{DP}\)

因为选一个数的方案时,只有前面选过的数量影响方案数。

\(\because \log28 \le 5\)

预处理 \(cnt\) 数组, \(cnt_i\) 表示 \(1 \to n\) 一共有多少个 \(f(x) = i\)

我们设 \(\large dp_{a,b,c,d,e,sum}\) 表示当前选了 \(a\)\(f(1)\)\(b\)\(f(2)\)\(c\)\(f(3)\)\(d\)\(f(4)\)\(e\)\(f(5)\),当前产生贡献 \(sum\)

显然当前一共枚举了 \(tot = a + b + c + d + e\) 个数。

有比较简单的转移方程(在这里不计边界条件):

\[\large dp_{a,b,c,d,e,sum} += \begin{cases} dp_{a-1,b,c,d,e,sum-1 \times tot \times (n-tot+1)} \times (cnt_1 - (a-1)) \\ dp_{a,b-1,c,d,e,sum-2 \times tot \times (n-tot+1)} \times (cnt_2 - (b-1)) \\ dp_{a,b,c-1,d,e,sum-3 \times tot \times (n-tot+1)} \times (cnt_3 - (c-1)) \\ dp_{a,b,c,d-1,e,sum-4 \times tot \times (n-tot+1)} \times (cnt_4 - (d-1)) \\ dp_{a,b,c,d,e-1,sum-5 \times tot \times (n-tot+1)} \times (cnt_5 - (e-1)) \end{cases} \]

总复杂度即总状态数,\(n=28\) 时约 \(O(10^7)\)

一定很简单对吧!!

0x4f 关于代码

码风勿喷,放两种格式的。

希望大家能看懂 (づ′▽`)づ~

\(\huge \mathscr{Code}\)

#include<bits/stdc++.h>
#define int long long 
using namespace std;
const int M = 1e4+5,MOD = 998244353,inv2 = 499122177,inv6 = 166374059;
int n,m,q,k;
int dp[16][9][5][3][2][M],cnt[6];
int f(int x){
	return log2(x&-x) + 1;
}
int solve1(){
	dp[0][0][0][0][0][0] = 1;
	memset(cnt,0,sizeof(cnt));
	for(int i=1;i<=n;i++) cnt[f(i)]++;
	for(int a=0;a<=cnt[1];a++){
		for(int b=0;b<=cnt[2];b++){
			for(int c=0;c<=cnt[3];c++){
				for(int d=0;d<=cnt[4];d++){
					for(int e=0;e<=cnt[5];e++){
						int tot = a+b+c+d+e,contribute = tot*(n-tot+1);
						for(int sum=1;sum<M;sum++){
							dp[a][b][c][d][e][sum] = 0;
							if(a && sum>=1*contribute) dp[a][b][c][d][e][sum] += dp[a-1][b][c][d][e][sum - 1*contribute] * (cnt[1]-a+1);
							if(b && sum>=2*contribute) dp[a][b][c][d][e][sum] += dp[a][b-1][c][d][e][sum - 2*contribute] * (cnt[2]-b+1);
							if(c && sum>=3*contribute) dp[a][b][c][d][e][sum] += dp[a][b][c-1][d][e][sum - 3*contribute] * (cnt[3]-c+1);
							if(d && sum>=4*contribute) dp[a][b][c][d][e][sum] += dp[a][b][c][d-1][e][sum - 4*contribute] * (cnt[4]-d+1);
							if(e && sum>=5*contribute) dp[a][b][c][d][e][sum] += dp[a][b][c][d][e-1][sum - 5*contribute] * (cnt[5]-e+1);
						}
					}
				}
			}
		}
	}
	int rec = 0;
	for(int i=1;i<M;i++){
		rec += dp[cnt[1]][cnt[2]][cnt[3]][cnt[4]][cnt[5]][i];
		if(rec>=k) return i;
	}
	return 0;
}
int sum(int l,int r){
	if(l>r) return 0;
	l %= MOD,r %= MOD;
	int s1 = (n+1)%MOD*(l+r)%MOD*(r-l+1)%MOD*inv2%MOD;
	int s2 = r*(r+1)%MOD*(2*r+1)%MOD*inv6%MOD;
	int s3 = l*(l-1)%MOD*(2*l-1)%MOD*inv6%MOD;
	return ((s1-s2+s3)%MOD+MOD)%MOD;
}
int solve2(){
	int l = 1,r = n,ans = 0;
	for(int i=log2(n)+1;i;i--){
		int c = (n>>(i-1)) - (n>>i);
		int dl = c>>1,dr = c-dl;
		if(l<n-r+1) swap(dl,dr);
		ans = (ans + i*sum(l,l+dl-1))%MOD;
		ans = (ans + i*sum(r-dr+1,r))%MOD;
		l += dl,r -= dr;
	}
	return ans;
}
signed main(){
	ios::sync_with_stdio(false);
	cin.tie(0),cout.tie(0);
	cin>>q;
	while(q--){
		cin>>n>>k;
		int ans = n<=28?solve1():solve2();
		cout<<ans<<'\n';
	}
	return 0;
}

#include<bits/stdc++.h>
#define int long long 
using namespace std;
const int M = 1e4 + 5, MOD = 998244353, inv2 = 499122177, inv6 = 166374059;
int n, m, q, k;
int dp[16][9][5][3][2][M], cnt[6];
int f(int x) {
	return log2(x & -x) + 1;
}
int solve1() {
	dp[0][0][0][0][0][0] = 1;
	memset(cnt, 0, sizeof(cnt));
	for (int i = 1; i <= n; i++) cnt[f(i)]++;
	for (int a = 0; a <= cnt[1]; a++) {
		for (int b = 0; b <= cnt[2]; b++) {
			for (int c = 0; c <= cnt[3]; c++) {
				for (int d = 0; d <= cnt[4]; d++) {
					for (int e = 0; e <= cnt[5]; e++) {
						int tot = a + b + c + d + e, contribute = tot * (n - tot + 1);
						for (int sum = 1; sum < M; sum++) {
							dp[a][b][c][d][e][sum] = 0;
							if (a && sum >= 1 * contribute) dp[a][b][c][d][e][sum] += dp[a - 1][b][c][d][e][sum - 1 * contribute] * (cnt[1] - a + 1);
							if (b && sum >= 2 * contribute) dp[a][b][c][d][e][sum] += dp[a][b - 1][c][d][e][sum - 2 * contribute] * (cnt[2] - b + 1);
							if (c && sum >= 3 * contribute) dp[a][b][c][d][e][sum] += dp[a][b][c - 1][d][e][sum - 3 * contribute] * (cnt[3] - c + 1);
							if (d && sum >= 4 * contribute) dp[a][b][c][d][e][sum] += dp[a][b][c][d - 1][e][sum - 4 * contribute] * (cnt[4] - d + 1);
							if (e && sum >= 5 * contribute) dp[a][b][c][d][e][sum] += dp[a][b][c][d][e - 1][sum - 5 * contribute] * (cnt[5] - e + 1);
						}
					}
				}
			}
		}
	}
	int rec = 0;
	for (int i = 1; i < M; i++) {
		rec += dp[cnt[1]][cnt[2]][cnt[3]][cnt[4]][cnt[5]][i];
		if (rec >= k) return i;
	}
	return 0;
}
int sum(int l, int r) {
	if (l > r) return 0;
	l %= MOD, r %= MOD;
	int s1 = (n + 1) % MOD * (l + r) % MOD * (r - l + 1) % MOD * inv2 % MOD;
	int s2 = r * (r + 1) % MOD * (2 * r + 1) % MOD * inv6 % MOD;
	int s3 = l * (l - 1) % MOD * (2 * l - 1) % MOD * inv6 % MOD;
	return ((s1 - s2 + s3) % MOD + MOD) % MOD;
}
int solve2() {
	int l = 1, r = n, ans = 0;
	for (int i = log2(n) + 1; i; i--) {
		int c = (n >> (i - 1)) - (n >> i);
		int dl = c >> 1, dr = c - dl;
		if (l < n - r + 1) swap(dl, dr);
		ans = (ans + i * sum(l, l + dl - 1)) % MOD;
		ans = (ans + i * sum(r - dr + 1, r)) % MOD;
		l += dl, r -= dr;
	}
	return ans;
}
signed main() {
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);
	cin >> q;
	while (q--) {
		cin >> n >> k;
		int ans = n <= 28 ? solve1() : solve2();
		cout << ans << '\n';
	}
	return 0;
}

posted @ 2025-08-08 16:05  OrangeRED  阅读(35)  评论(0)    收藏  举报