洛谷P4093 [HEOI2016/TJOI2016] 序列 题解 树套树(树状数组 套 splay tree)

题目链接:https://www.luogu.com.cn/problem/P4093

解题思路完全来自 一剑霜寒十四洲 大佬的博客

大题思路是:

我们设3个数组:

  • \(a[i]\) 表示原来第 \(i\) 个位置上的值。
  • \(maxa[i]\) 表示第 \(i\) 个位置上可以变成的最大值。
  • \(mina[i]\) 表示第 \(i\) 个位置上可以变成的最小值。

要满足在任意一种变化中,选出的子序列中第 \(i\) 个位置的上一个位置 \(j\) 是符合要求的,需要满足:

  1. \(j \lt i\) 这一条很显然。
  2. \(maxa[j] \le a[i]\)\(j\) 的位置上的数变成最大值时序列仍然不降。
  3. \(a[j] \le mina[i]\)\(i\) 的位置上的数变成最小值时序列仍然不降。

于是一个dp就很显然了:

\(f[i]= \max(f[j])\)\(j\) 要满足上述条件。

可以发现,有1,2,3这3条要求,不就是一个三维偏序问题吗?跟 陌上花开 那道题非常像。

首先从小到大枚举 \(i\),可以降掉第一维。

第二维和第三维直接树套树搞定。


实现的时候(因为我是树状数组 套 splay tree),在 splay tree 的每个节点除了数值 v 之外,还需要额外维护两个信息:

  • \(f\):这个点对应的 \(f_i\)
  • \(maxf\):以这个点为根的子树中所有 \(f\) 的最大值。

示例程序:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 5, N = 1e5, inf = 1e9;

struct SplayTree {

    struct Node {
        int s[2], v, f, maxf, p, sz;

        Node() {}
        Node(int _v, int _p, int _f) {
            s[0] = s[1] = 0;
            v = _v;
            p = _p;
            maxf = f = _f;
            sz = 1;
        }
    } tr[maxn*30];

    int idx;

    void push_up(int u) {
        int l = tr[u].s[0], r = tr[u].s[1];
        tr[u].sz = tr[l].sz + tr[r].sz + 1;
        tr[u].maxf = max({ tr[u].f, tr[l].maxf, tr[r].maxf });
    }

    void f_s(int p, int u, int k) {
        tr[p].s[k] = u;
        tr[u].p = p;
    }

    void rot(int x) {
        int y = tr[x].p, z = tr[y].p;
        int k = tr[y].s[1] == x;
        f_s(z, x, tr[z].s[1] == y);
        f_s(y, tr[x].s[k^1], k);
        f_s(x, y, k^1);
        push_up(y), push_up(x);
    }

    void splay(int &root, int x, int k) {
        while (tr[x].p != k) {
            int y = tr[x].p, z = tr[y].p;
            if (z != k)
                (tr[y].s[1]==x)^(tr[z].s[1]==y) ? rot(x) : rot(y);
            rot(x);
        }
        if (!k) root = x;
    }

    int get_rnk(int &root, int x) {
        int maxf = 0, u = root, p = 0;
        while (u) {
            p = u;
            if (tr[u].v <= x) {
                maxf = max({ maxf, tr[u].f, tr[ tr[u].s[0] ].maxf });
                u = tr[u].s[1];
            }
            else
                u = tr[u].s[0];
        }
        if (p) splay(root, p, 0);
        return maxf; // 减去一个哨兵节点
    }

    void ins(int &root, int v, int f) {
        int u = root, p = 0, k = 0;
        while (u) {
            tr[u].sz++;
            k = tr[u].v < v;
            p = u;
            u = tr[u].s[k];
        }
        u = ++idx;
        tr[u] = Node(v, p, f);
        if (p) tr[p].s[k] = u;
        splay(root, u, 0);
    }

    void del(int &root, int v) {
        int u = root;
        while (u) {
            if (tr[u].v == v) break;
            else u = tr[u].s[tr[u].v < v];
        }
        splay(root, u, 0);
        int l = tr[u].s[0], r = tr[u].s[1];
        if (!l || !r) {
            root = l + r;
            tr[root].p = 0;
        }
        else {
            while (tr[l].s[1]) l = tr[l].s[1];
            splay(root, l, 0);
            tr[l].s[1] = r;
            tr[r].p = l;
            push_up(l);
        }
    }

} splay_t;

struct BIT {

    int tr[maxn];

    int lowbit(int x) { return x & -x; }

    void init() {
        for (int i = 1; i <= N; i++) {
            splay_t.ins(tr[i], -inf, 0);
            splay_t.ins(tr[i], inf, 0);
        }
    }

    void add(int p, int val, int f) {
        for (int i = p; i <= N; i += lowbit(i))
            splay_t.ins(tr[i], val, f);
    }

    int query(int p, int val) {
        int res = 0;
        for (int i = p; i; i -= lowbit(i))
            res = max(res, splay_t.get_rnk(tr[i], val));
        return res;
    }

} bit;

int n, m, a[maxn], maxa[maxn], mina[maxn], f[maxn], ans;

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%d", a+i);
        maxa[i] = mina[i] = a[i];
    }
    for (int i = 0, x, y; i < m; i++) {
        scanf("%d%d", &x, &y);
        maxa[x] = max(maxa[x], y);
        mina[x] = min(mina[x], y);
    }
    bit.init();
    for (int i = 1; i <= n; i++) {
        f[i] = bit.query(a[i], mina[i]) + 1;
        ans = max(ans, f[i]);
        bit.add(maxa[i], a[i], f[i]);
    }
    printf("%lld\n", ans);
    return 0;
}
posted @ 2026-01-08 18:36  quanjun  阅读(6)  评论(0)    收藏  举报