CF2003F. Turtle and Three Sequences
给你三个长为 \(n\) 的序列 \(a,b,c\)。
求所有满足一下条件的 \([1,2,\cdots,n]\) 的长为 \(m\) 的子序列 \(p_1,p_2,\cdots,p_m\) 中,\(\sum_{i=1}^m c_{p_i}\) 的最大值
- \(a_{p_1}\le a_{p_2}\le\cdots\le a_{p_m}\)。
- \(\forall i \neq j,b_{p_i} \neq b_{p_j}\),即 \(b_i\) 互异。
\(1\le a_i,b_i\le n\le 3\times 10^3,1\le m\le 5\)
我们发现 \(b_i\) 互不相同这个条件非常难办,所以肯定要把它给转化掉。
有两种方法。
第一种是我们考虑,假如把第二个条件改成 \(b\) 单增,那么就是二维偏序,可以 \(\mathcal O(nm\log^2 n)\) 解决。
那么我们每次把 \(b\) 映射到一个随机的排列,有 \(\frac{1}{m!}\approx .008\) 的概率答案中的 \(b\) 被映射到单增序列上。
那么做多次即可,时间复杂度 \(\mathcal O(nm\log^2n \times m!)\)。
第二种是题解做法,我们考虑把每种颜色映射到 \(1\sim m\) 中的随机一种,然后直接状压,复杂度 \(O(n\log n2^m)\)。
正确率为答案中的 \(b\) 刚好被映射到排列上,也就是 \(\frac{m!}{m^m} \approx .03\)。
时间复杂度 \(\mathcal O(2^mn\log n \times \frac{m^m}{m!})\)。
我写的第一种,卡了好久才过 /ll
#include <algorithm>
#include <iostream>
#include <random>
const int N = 3001, M = 512;
std::mt19937 rnd(std::random_device{}());
auto upd = [](auto& x, auto&& y) {
x = std::max(x, y);
};
struct Matr {
int matr[N][M];
inline void add(int x, int y, int z) {
for(; x < N; x += x & -x)
for(int j = y; j < M; j += j & -j)
upd(matr[x][j], z);
}
inline int sum(int x, int y) {
int z = 0;
for(; x; x -= x & -x)
for(int j = y; j; j -= j & -j)
upd(z, matr[x][j]);
return z;
}
inline void clear(int x, int y) {
for(; x < N; x += x & -x)
for(int j = y; j < M; j += j & -j)
matr[x][j] = 0;
}
};
Matr Mat[4];
int n, m, a[N], rb[N], b[N], c[N], w[N], q[N];
int solve() {
std::shuffle(w + 1, w + n + 1, rnd);
for(int i = 1; i <= n; ++i) b[i] = w[rb[i]];
int ans = 0;
for(int i = 1; i <= n; ++i) {
Mat[0].add(a[i], b[i], c[i]);
for(int t = 1; t < m - 1; ++t)
if(auto val = Mat[t-1].sum(a[i], b[i]-1); val)
Mat[t].add(a[i], b[i], val + c[i]);
if(auto val = Mat[m-2].sum(a[i], b[i]-1); val)
upd(ans, val + c[i]);
}
for(int i = 1; i <= n; ++i)
for(int x = a[i]; x < N; x += x & -x)
for(int y = b[i]; y < M; y += y & -y)
for(int t = 0; t < m - 1; ++t)
Mat[t].matr[x][y] = 0;
return ans;
}
int T = clock();
int main() {
std::ios::sync_with_stdio(0), std::cin.tie(0), std::cout.tie(0);
std::cin >> n >> m;
for(int i = 1; i <= n; ++i) std::cin >> a[i];
for(int i = 1; i <= n; ++i) std::cin >> rb[i];
for(int i = 1; i <= n; ++i) std::cin >> c[i];
for(int i = 1; i <= n; ++i) w[i] = i % (M - 1) + 1;
if(m == 1) {
int _ans = 0;
for(int i = 1; i <= n; ++i)
upd(_ans, c[i]);
std::cout << _ans << "\n";
} else {
int _ans = 0;
for(int T = 666; T--; )
upd(_ans, solve());
if(_ans == 0) std::cout << "-1\n";
else std::cout << _ans << "\n";
}
}
本文来自博客园,作者:CuteNess,转载请注明原文链接:https://www.cnblogs.com/CuteNess/p/19121922

浙公网安备 33010602011771号