算法题 跳跃游戏:倍增+st表+基环树
题目
跳跃游戏
给定一个数组,每个元素代表的节点能够单向跳跃到固定的另一个节点。
给出每个节点能转移的节点,进行q次查询,每次查询给出一个区间和转移次数,游戏规则是可以在给出的区间中选取任意一个节点作为起点跳跃最多x次,需要回答:在此区间内选择任意起点最多跳x次能到达的最大的节点序号是多少。
输入:
第一行输入两个整数 n (1≤n≤50000)和 q (1≤q≤10^5),分别表示节点个数和查询次数。
第二行包含 n 个随机生成的整数 a1, a2, ..., an (1 < ai < n), 表示每个节点能转移到的节点
接下来q行,每行三个整数l(1≤l≤n), r(l≤r≤n), x(0≤x≤n), 表示每次查询可选的起始区间和最多可进行的转移次数
输出:
输出q行,作为每次查询的输出。对于每次查询,输出一个整数,表示在不超过x次转移操作下,能够到达的最大的点位编号。
样例输入:
10 10
4 2 1 8 9 3 9 7 5 5
5 7 2
3 7 6
3 5 9
9 9 9
6 9 3
9 10 3
6 6 3
1 9 6
1 3 10
1 3 3
样例输出:
9
9
9
9
9
10
6
9
9
8
思路
单次查询要在logn时间内完成,考虑用倍增和st表预处理。问题是lrx是动态变化的,不可能对每个x预处理st表,分析如何剪枝。
基环树:节点转移路线构成的结构是内向基环森林(每个节点的出度为1,节点数等于边数,每个连通图都是一个内向基环树),那么对于每个基环树内的节点,存在有限步数可达的最大节点,预处理该信息为max_reach[i]和time_to_max[i]
倍增法:计算go[k][i]为点i走2^k步到达的位置,mx[k][i]为经过的最大节点编号。
st表:对倍增数组mx[k]建表,st_path[k][j][i]表示[i,i+2^j-1]区间内节点走2^k步经过的最大值。st_reach_idx[j][i]表示区间内到达最大节点的起点编号,如果多个点的最大值相同,存储步数更小的
分治和递归剪枝:
对于区间lr,如果x==2^k或x>=n,O(1)查询返回;
否则,利用x的二进制最高位k查询能到达的最大节点ans,并查询区间内节点能抵达的上限节点val,如果val不大于ans,则ans为最大值,返回;
如果上限节点大于ans,且其步数<=x,可达,返回val;
否则利用倍增计算该点实际走x步经过的最大点(O(logx))并更新ans,然后递归处理其左右子区间[l,idx-1]和[idx+1,r]。
记忆化:map记录搜过的答案
时间复杂度
预处理
拓扑排序+树上反向递推,O(n)
倍增数组建logn层,O(nlogn)
st表,对倍增数组建表,O(nlognlogn)
总计O(nlognlogn),约10^7
查询
最好情况O(1)查询命中
最坏情况O(nlogn),所有点的步数都刚好不够,但由于ans的单调剪枝,很难构造出跑满O(nlogn)的数据,且记忆化搜索避免了重复计算
ac代码
#include <iostream>
#include <algorithm>
#include <vector>
#include <cmath>
#include <cstring>
#include <map>
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
const int MAXN = 50005;
const int LOGN = 17;
int n, q;
int a[MAXN];
int go[LOGN][MAXN];
int mx[LOGN][MAXN];
int max_reach[MAXN];
int time_to_max[MAXN];
int deg[MAXN];
int st_path[LOGN][LOGN][MAXN];
int st_reach_idx[LOGN][MAXN];
int lg2[MAXN];
map<long long, int> memo;
void init_max_reach_and_time() {
fill(deg, deg + n + 1, 0);
for (int i = 1; i <= n; ++i) deg[a[i]]++;
static int q_topo[MAXN];
int head = 0, tail = 0;
for (int i = 1; i <= n; ++i) if (deg[i] == 0) q_topo[tail++] = i;
while(head < tail){
int u = q_topo[head++];
int v = a[u];
deg[v]--;
if (deg[v] == 0) q_topo[tail++] = v;
}
for (int i = 1; i <= n; ++i) {
if (deg[i] > 0 && max_reach[i] == 0) {
vector<int> cycle;
int curr = i;
int cycle_max_val = 0;
while(max_reach[curr] == 0) {
max_reach[curr] = -1;
cycle.push_back(curr);
cycle_max_val = max(cycle_max_val, curr);
curr = a[curr];
}
int sz = cycle.size();
for(int u : cycle) max_reach[u] = cycle_max_val;
for(int u : cycle) time_to_max[u] = 2 * n;
int dist = 2 * n;
for(int k = 2 * sz - 1; k >= 0; --k) {
int u = cycle[k % sz];
if (u == cycle_max_val) dist = 0;
else dist++;
time_to_max[u] = min(time_to_max[u], dist);
}
}
}
for (int i = tail - 1; i >= 0; --i) {
int u = q_topo[i];
int v = a[u];
max_reach[u] = max(u, max_reach[v]);
if (u == max_reach[u]) time_to_max[u] = 0;
else time_to_max[u] = time_to_max[v] + 1;
}
}
void init() {
lg2[1] = 0;
for (int i = 2; i <= n; ++i) lg2[i] = lg2[i >> 1] + 1;
for (int i = 1; i <= n; ++i) {
go[0][i] = a[i];
mx[0][i] = max(i, a[i]);
}
for (int k = 1; k < LOGN; ++k) {
for (int i = 1; i <= n; ++i) {
go[k][i] = go[k-1][ go[k-1][i] ];
mx[k][i] = max(mx[k-1][i], mx[k-1][ go[k-1][i] ]);
}
}
init_max_reach_and_time();
for (int k = 0; k < LOGN; ++k) {
for (int i = 1; i <= n; ++i) st_path[k][0][i] = mx[k][i];
for (int j = 1; j < LOGN; ++j) {
int len = 1 << (j - 1);
int limit = n - (1 << j) + 1;
for (int i = 1; i <= limit; ++i) {
st_path[k][j][i] = max(st_path[k][j-1][i], st_path[k][j-1][i + len]);
}
}
}
for (int i = 1; i <= n; ++i) st_reach_idx[0][i] = i;
for (int j = 1; j < LOGN; ++j) {
int len = 1 << (j - 1);
int limit = n - (1 << j) + 1;
for (int i = 1; i <= limit; ++i) {
int idx1 = st_reach_idx[j-1][i];
int idx2 = st_reach_idx[j-1][i + len];
bool pick1 = false;
if (max_reach[idx1] > max_reach[idx2]) pick1 = true;
else if (max_reach[idx1] == max_reach[idx2]) {
if (time_to_max[idx1] <= time_to_max[idx2]) pick1 = true;
}
st_reach_idx[j][i] = pick1 ? idx1 : idx2;
}
}
}
inline int query_path_st(int k, int L, int R) {
if (L > R) return 0;
int j = lg2[R - L + 1];
return max(st_path[k][j][L], st_path[k][j][R - (1 << j) + 1]);
}
inline int query_reach_idx(int L, int R) {
int j = lg2[R - L + 1];
int idx1 = st_reach_idx[j][L];
int idx2 = st_reach_idx[j][R - (1 << j) + 1];
if (max_reach[idx1] > max_reach[idx2]) return idx1;
if (max_reach[idx2] > max_reach[idx1]) return idx2;
return (time_to_max[idx1] <= time_to_max[idx2]) ? idx1 : idx2;
}
inline int calc_val_in_steps(int u, int x) {
int res = u;
for (int k = LOGN - 1; k >= 0; --k) {
if ((x >> k) & 1) {
res = max(res, mx[k][u]);
u = go[k][u];
}
}
return res;
}
void solve_recursive(int L, int R, int limit_x, int &ans) {
if (L > R) return;
int idx = query_reach_idx(L, R);
int potential = max_reach[idx];
if (potential <= ans) return;
if (time_to_max[idx] <= limit_x) {
ans = max(ans, potential);
return;
}
ans = max(ans, calc_val_in_steps(idx, limit_x));
solve_recursive(L, idx - 1, limit_x, ans);
solve_recursive(idx + 1, R, limit_x, ans);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
if (!(cin >> n >> q)) return 0;
for (int i = 1; i <= n; ++i) cin >> a[i];
init();
while (q--) {
int l, r, x;
cin >> l >> r >> x;
if (x == 0) {
cout << r << "\n";
continue;
}
long long key = ((long long)l << 34) | ((long long)r << 17) | x;
if (memo.count(key)) {
cout << memo[key] << "\n";
continue;
}
int k = lg2[x];
int ans = query_path_st(k, l, r);
if (ans == n) {
cout << n << "\n";
continue;
}
solve_recursive(l, r, x, ans);
memo[key] = ans;
cout << ans << "\n";
}
return 0;
}
总结
想不到更简单通用的解法,用了各种优化才ac了

浙公网安备 33010602011771号