CF2003F. Turtle and Three Sequences

给你三个长为 \(n\) 的序列 \(a,b,c\)

求所有满足一下条件的 \([1,2,\cdots,n]\) 的长为 \(m\) 的子序列 \(p_1,p_2,\cdots,p_m\) 中,\(\sum_{i=1}^m c_{p_i}\) 的最大值

  • \(a_{p_1}\le a_{p_2}\le\cdots\le a_{p_m}\)
  • \(\forall i \neq j,b_{p_i} \neq b_{p_j}\),即 \(b_i\) 互异。

\(1\le a_i,b_i\le n\le 3\times 10^3,1\le m\le 5\)


我们发现 \(b_i\) 互不相同这个条件非常难办,所以肯定要把它给转化掉。

有两种方法。

第一种是我们考虑,假如把第二个条件改成 \(b\) 单增,那么就是二维偏序,可以 \(\mathcal O(nm\log^2 n)\) 解决。

那么我们每次把 \(b\) 映射到一个随机的排列,有 \(\frac{1}{m!}\approx .008\) 的概率答案中的 \(b\) 被映射到单增序列上。

那么做多次即可,时间复杂度 \(\mathcal O(nm\log^2n \times m!)\)

第二种是题解做法,我们考虑把每种颜色映射到 \(1\sim m\) 中的随机一种,然后直接状压,复杂度 \(O(n\log n2^m)\)

正确率为答案中的 \(b\) 刚好被映射到排列上,也就是 \(\frac{m!}{m^m} \approx .03\)

时间复杂度 \(\mathcal O(2^mn\log n \times \frac{m^m}{m!})\)


我写的第一种,卡了好久才过 /ll

#include <algorithm>
#include <iostream>
#include <random>

const int N = 3001, M = 512;

std::mt19937 rnd(std::random_device{}());

auto upd = [](auto& x, auto&& y) {
	x = std::max(x, y);
};

struct Matr {
	int matr[N][M];
	
	inline void add(int x, int y, int z) {
		for(; x < N; x += x & -x)
			for(int j = y; j < M; j += j & -j)
				upd(matr[x][j], z);
	}
	
	inline int sum(int x, int y) {
		int z = 0;
		for(; x; x -= x & -x)
			for(int j = y; j; j -= j & -j)
				upd(z, matr[x][j]);
		return z;
	}
	
	inline void clear(int x, int y) {
		for(; x < N; x += x & -x)
			for(int j = y; j < M; j += j & -j)
				matr[x][j] = 0;
	}
};

Matr Mat[4];
int n, m, a[N], rb[N], b[N], c[N], w[N], q[N];

int solve() {
	std::shuffle(w + 1, w + n + 1, rnd);
	for(int i = 1; i <= n; ++i) b[i] = w[rb[i]];
	int ans = 0;
	for(int i = 1; i <= n; ++i) {
		Mat[0].add(a[i], b[i], c[i]);
		for(int t = 1; t < m - 1; ++t)
			if(auto val = Mat[t-1].sum(a[i], b[i]-1); val)
				Mat[t].add(a[i], b[i], val + c[i]);
		if(auto val = Mat[m-2].sum(a[i], b[i]-1); val)
			upd(ans, val + c[i]);
	}
	for(int i = 1; i <= n; ++i) 
		for(int x = a[i]; x < N; x += x & -x)
			for(int y = b[i]; y < M; y += y & -y)
				for(int t = 0; t < m - 1; ++t)
					Mat[t].matr[x][y] = 0;
	return ans;
}

int T = clock();

int main() {
	std::ios::sync_with_stdio(0), std::cin.tie(0), std::cout.tie(0);
	std::cin >> n >> m;
	for(int i = 1; i <= n; ++i) std::cin >> a[i];
	for(int i = 1; i <= n; ++i) std::cin >> rb[i];
	for(int i = 1; i <= n; ++i) std::cin >> c[i];
	for(int i = 1; i <= n; ++i) w[i] = i % (M - 1) + 1;
	
	if(m == 1) {
		int _ans = 0;
		for(int i = 1; i <= n; ++i)
			upd(_ans, c[i]);
		std::cout << _ans << "\n";
	} else {
		int _ans = 0;
		for(int T = 666; T--; )
			upd(_ans, solve());
		if(_ans == 0) std::cout << "-1\n";
		else std::cout << _ans << "\n";
	}
}
posted @ 2025-09-30 23:18  CuteNess  阅读(8)  评论(0)    收藏  举报