UOJ #919. 【UR #28】环环相扣 题解

Description

给定一个长度为 \(n\) 的整数序列 \(a_1\sim a_n\),其中的元素两两互不相等。

\(q\) 个询问,每个询问给定一个区间 \([l,r]\),你要选择三个下标 \(i,j,k\in[l,r]\) 满足 \(i\neq j,j\neq k,k\neq i\),最大化 \((a_i\bmod a_j)+(a_j\bmod a_k)+(a_k\bmod a_i)\) 的值。

你只需要输出这个最大值。

\(3\leq n\leq2\times10^6\)\(1\leq q\leq8\times10^5\)\(\text{op}\in\{0,1\}\)\(1\leq a_i\leq10^{18}\)\(a_1\sim a_n\) 互不相等,\(1\leq l\leq r\leq n\)\(r-l+1\geq3\)

Solution

不妨设 \(a_x>a_y>a_z\),那么对于 \((x,y,z)\) 只有两种贡献:\(a_x\bmod a_y+a_y\bmod a_z+a_z\)\(a_x\bmod a_z+a_z+a_y\)

对于一组询问 \([l,r]\),有个结论是 \([l,r]\) 内的区间最大值和次大值都必须选。

  • 证明

    先把区间的数拿出来并排序,使得 \(a_1<a_2<\ldots<a_m\),则选择 \((m-2,m-1,m)\) 可以得到一个答案下界为 \(a_{m-1}+a_{m-2}\)

    • 如果最终答案为 \(a_x\bmod a_y+a_y\bmod a_z+a_z\),由于 \(a_x\bmod a_y+a_y\bmod a_z+a_z\leq\min\{a_x,2\cdot a_y-1\}\),则 \(x\leq m-1\)\(y\leq m-2\) 一定没上面那个优,所以 \(x=m\)\(y=m-1\)

    • 如果最终答案为 \(a_x\bmod a_z+a_z+a_y\),由于 \(a_x\bmod a_z+a_z+a_y\leq a_x+a_y\),当 \(x\neq m\) 一定达不到最优解,又因为 \(a_y\) 只出现了一次,所以 \(y\) 一定尽量取到 \(m-1\)

于是 \(x\)\(y\) 就固定了,设 \(F(x,l,r)\) 表示将 \([l,r]\) 中的 \(x\) 和把剩下的最大值去掉后的所有 \(k\)\(a_x\bmod a_k+a_k\) 的最大值。

那么答案就是 \(\max\{F(x,l,r)+a_y,F(y,l,r)+a_x\bmod a_y\}\),由于 \(x,y\) 已经确定,所以我们只需要求出 \(F(x,l,r)\) 的值即可。


考虑将 \(F(x,l,r)\) 拆成 \(F(x,l,x-1)\)\(F(x,x+1,r)\)。对于 \(F(x,l,x-1)\),让 \(l\)\(x\) 枚举到 \(1\) 可以得到一个 \(O(n^2)\) 的做法。

又有个结论是如果扫到了某个 \(l\),如果存在至少两个 \(a_i>\frac{a_x}{2}\) 就可以停止扫描。

  • 证明

    不妨设这两个数是 \(a_i,a_j\)\(a_i<a_j\)

    • 如果 \(a_i>a_x\),与 \(x\) 为区间最大值/次大值矛盾,这个区间一定不会被询问到。

    • 如果 \(\frac{a_x}{2}<a_i<a_x\),则 \(a_i\bmod a_x+a_x=a_i\),已经到了最大值,前面的一定不会更优。

基于这个做法暴力枚举 \(l\) 可以做到 \(O(n\log V+q\log n)\),过不了。


还有个结论是扫描到 \(a_i\) 时,如果已经存在两个数 \(\geq 2\cdot a_i\),则 \(a_i\) 就可以删掉。

  • 证明

    如果 \(a_j\geq 2\cdot a_i\),则 \(a_x\bmod a_j+a_j\geq a_j\geq 2\cdot a_i>a_x\bmod a_i+a_i\),所以 \(a_i\) 一定不会对答案造成贡献。

所以可以在从小到大枚举 \(x\) 的过程中,维护一个栈表示目前还没删掉的数和这些数的删除标记。然后在扫描 \(l\) 的过程中,维护另一个标记表示 \(>\frac{a_x}{2}\) 的个数。如果当前 \(a_i\leq \frac{a_x}{2}\) 就将 \(i\) 的删除标记加 \(1\)。否则将另一个标记加一,如果另一个标记到了 \(2\) 就停止扫描,并把栈里面删除标记为 \(2\) 的数删掉,并将 \(x\) 加到栈里。

容易证明上面那个做法的预处理复杂度为 \(O(n)\)

时间复杂度:\(O(n+q\log n)\)

Code

#include <bits/stdc++.h>

// #define int int64_t

const int kMaxN = 2e6 + 5;

int n, q, tp;
int64_t a[kMaxN];
std::vector<std::pair<int, int64_t>> vecl[kMaxN], vecr[kMaxN];

int get(int x, int y) { return a[x] > a[y] ? x : y; }

struct SGT {
  int N, mx[kMaxN * 4];

  void pushup(int x) {
    mx[x] = get(mx[x << 1], mx[x << 1 | 1]);
  }

  void build(int n) {
    for (N = 1; N <= n + 1; N <<= 1) {}
    for (int i = N; i <= N + n; ++i) mx[i] = i - N;
    for (int i = N - 1; i; --i) pushup(i);
  }

  int query(int l, int r) {
    int ret = 0;
    for (l += N - 1, r += N + 1; l ^ r ^ 1; l >>= 1, r >>= 1) {
      if (~l & 1) ret = get(ret, mx[l ^ 1]);
      if (r & 1) ret = get(ret, mx[r ^ 1]);
    }
    return ret;
  }
} sgt;

void getl() {
  static int stk[kMaxN] = {0}, cnt[kMaxN] = {0}, tmp[kMaxN];
  int top = 0;
  for (int i = 1; i <= n; ++i) {
    int now = 0, mx = 0, cur = top + 1;
    int64_t res = LLONG_MIN;
    for (int j = top; j; --j) {
      if (!mx) {
        mx = stk[j];
      } else if (a[stk[j]] < a[mx]) {
        res = std::max(res, a[i] % a[stk[j]] + a[stk[j]]);
      } else {
        res = std::max(res, a[i] % a[mx] + a[mx]);
        mx = stk[j];
      }
      vecl[i].emplace_back(stk[j], res);
      if (a[stk[j]] > a[i] / 2) {
        if (++now == 2) break;
      } else {
        ++cnt[stk[j]];
      }
      cur = j;
    }
    int m = 0;
    for (int j = cur; j <= top; ++j)
      if (cnt[stk[j]] < 2)
        tmp[++m] = stk[j];
    top = cur - 1;
    for (int j = 1; j <= m; ++j) stk[++top] = tmp[j];
    stk[++top] = i;
    std::reverse(vecl[i].begin(), vecl[i].end());
  }
}

void getr() {
  static int stk[kMaxN] = {0}, cnt[kMaxN] = {0}, tmp[kMaxN];
  int top = 0;
  for (int i = n; i; --i) {
    int now = 0, mx = 0, cur = top + 1;
    int64_t res = LLONG_MIN;
    for (int j = top; j; --j) {
      if (!mx) {
        mx = stk[j];
      } else if (a[stk[j]] < a[mx]) {
        res = std::max(res, a[i] % a[stk[j]] + a[stk[j]]);
      } else {
        res = std::max(res, a[i] % a[mx] + a[mx]);
        mx = stk[j];
      }
      vecr[i].emplace_back(stk[j], res);
      if (a[stk[j]] > a[i] / 2) {
        if (++now == 2) break;
      } else {
        ++cnt[stk[j]];
      }
      cur = j;
    }
    int m = 0;
    for (int j = cur; j <= top; ++j)
      if (cnt[stk[j]] < 2)
        tmp[++m] = stk[j];
    top = cur - 1;
    for (int j = 1; j <= m; ++j) stk[++top] = tmp[j];
    stk[++top] = i;
    std::reverse(vecr[i].begin(), vecr[i].end());
  }
}

void prework() {
  sgt.build(n);
  getl(), getr();
}

int64_t F(int x, int l, int r, int mx) {
  int64_t ret = LLONG_MIN;
  int y = sgt.query(l, x - 1), z = sgt.query(x + 1, r);
  if (y && y != mx) ret = std::max(ret, a[y] + a[x] % a[y]);
  if (z && z != mx) ret = std::max(ret, a[z] + a[x] % a[z]);
  auto it1 = std::lower_bound(vecl[x].begin(), vecl[x].end(), std::pair<int, int64_t>{l, LLONG_MIN});
  auto it2 = std::lower_bound(vecr[x].begin(), vecr[x].end(), std::pair<int, int64_t>{r, LLONG_MAX}, std::greater<>());
  if (it1 != vecl[x].end()) ret = std::max(ret, it1->second);
  if (it2 != vecr[x].end()) ret = std::max(ret, it2->second);
  return ret;
}

int64_t query(int l, int r) {
  int x = sgt.query(l, r), y = get(sgt.query(l, x - 1), sgt.query(x + 1, r));
  return std::max(F(x, l, r, y) + a[y], F(y, l, r, x) + a[x] % a[y]);
}

void dickdreamer() {
  std::cin >> n >> q >> tp;
  for (int i = 1; i <= n; ++i) std::cin >> a[i];
  prework();
  int64_t lastans = 0;
  for (int i = 1; i <= q; ++i) {
    int l, r;
    std::cin >> l >> r;
    l = (l + lastans * tp - 1) % n + 1;
    r = (r + lastans * tp - 1) % n + 1;
    std::cout << (lastans = query(l, r)) << '\n';
  }
}

int32_t main() {
#ifdef ORZXKR
  freopen("in.txt", "r", stdin);
  freopen("out.txt", "w", stdout);
#endif
  std::ios::sync_with_stdio(0), std::cin.tie(0), std::cout.tie(0);
  int T = 1;
  // std::cin >> T;
  while (T--) dickdreamer();
  // std::cerr << 1.0 * clock() / CLOCKS_PER_SEC << "s\n";
  return 0;
}
posted @ 2024-11-26 12:30  下蛋爷  阅读(171)  评论(0)    收藏  举报