线段树进阶

线段树进阶

标记永久化

在主席树等线段树结构中,若每次都把懒标记下传,则需要新建很多点,空间常数过大。

主要思路是不进行 pushdownpushup 操作,每次查询时把一路上的标记信息都整合到一起。

一般的,能标记永久化的信息通常满足:

  • 信息在区间修改后必须支持快速更新。
  • 修改与顺序无关。
    • 区间赋值虽然与顺序有关,但是可以同时维护时间戳转化为求时间戳的最值。

如下是一份区间加+区间求和的模板:P3372 【模板】线段树 1

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7;

int a[N];

int n, m;

namespace SMT {
ll s[N << 2];
int tag[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

void build(int x, int l, int r) {
    if (l == r) {
        s[x] = a[l];
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    s[x] = s[ls(x)] + s[rs(x)];
}

void update(int x, int nl, int nr, int l, int r, int k) {
    s[x] += 1ll * (min(r, nr) - max(l, nl) + 1) * k;

    if (l <= nl && nr <= r) {
        tag[x] += k;
        return;
    }

    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);
}

ll query(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return s[x];

    int mid = (nl + nr) >> 1;
    ll res = 1ll * (min(r, nr) - max(l, nl) + 1) * tag[x];

    if (l <= mid)
        res += query(ls(x), nl, mid, l, r);
    
    if (r > mid)
        res += query(rs(x), mid + 1, nr, l, r);

    return res;
}
} // namespace SMT

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    SMT::build(1, 1, n);

    while (m--) {
        int op, l, r;
        scanf("%d%d%d", &op, &l, &r);

        if (op == 1) {
            int k;
            scanf("%d", &k);
            SMT::update(1, 1, n, l, r, k);
        } else
            printf("%lld\n", SMT::query(1, 1, n, l, r));
    }

    return 0;
}

线段树合并

假设两颗线段树为 \(A\)\(B\) ,考虑从根开始递归合并:

  • 递归到某个节点时,如果 \(A\) 树或者 \(B\) 树上的对应节点为空,直接返回另一个树上对应节点。
  • 如果递归到叶子节点,就合并两棵树上的对应节点。
  • 最后,根据子节点更新当前节点并且返回。

时间复杂度为公共点的数量。

int merge(int a, int b, int l, int r) {
    if (!a || !b)
        return a | b;
    
    if (l == r) {
        // merge in need
        return a;
    }
    
    int mid = (l + r) >> 1;
    lc[a] = merge(lc[a], lc[b], l, mid);
    rc[a] = merge(rc[a], rc[b], mid + 1, r);
    return pushup(a), a;
}

如果仍需要保留原来的信息,则可以通过新建节点的方式处理。

P4556 [Vani有约会] 雨天的尾巴 /【模板】线段树合并

给出一棵树,\(m\) 次操作,每次将树上 \(x \to y\) 路径上的点加入颜色 \(z\) 。最后对于所有点求出该点最多的颜色。

\(n, m \le 10^5\)

考虑用线段树存储每种颜色的出现次数,用树上差分转化为单点加和子树查。对于子树查就直接将每个儿子的线段树与自己合并起来即可。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7, LOGN = 17;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].push_back(v);
    }
} G;

int fa[N][LOGN], dep[N], ans[N];

int n, m;

void dfs(int u, int f) {
    fa[u][0] = f, dep[u] = dep[f] + 1;
    
    for (int i = 1; i < LOGN; ++i)
        fa[u][i] = fa[fa[u][i - 1]][i - 1];
    
    for (auto v : G.e[u])
        if (v != f)
            dfs(v, u);
}

inline int LCA(int x, int y) {
    if (dep[x] < dep[y])
        swap(x, y);
    
    for (int h = dep[x] - dep[y]; h; h &= h - 1)
        x = fa[x][__builtin_ctz(h)];
    
    if (x == y)
        return x;
    
    for (int i = LOGN - 1; ~i; --i)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];
    
    return fa[x][0];
}

namespace SMT {
const int S = 3e7 + 7;

int lc[S], rc[S], cnt[S], ans[S];
int rt[N];

int tot;

inline void pushup(int x) {
    if (cnt[lc[x]] >= cnt[rc[x]])
        cnt[x] = cnt[lc[x]], ans[x] = ans[lc[x]];
    else
        cnt[x] = cnt[rc[x]], ans[x] = ans[rc[x]];
}

void update(int &x, int nl, int nr, int pos, int k) {
    if (!x)
        x = ++tot;
    
    if (nl == nr) {
        cnt[x] += k, ans[x] = nl;
        return ;
    }
    
    int mid = (nl + nr) >> 1;
    
    if (pos <= mid)
        update(lc[x], nl, mid, pos, k);
    else
        update(rc[x], mid + 1, nr, pos, k);
    
    pushup(x);
}

int merge(int a, int b, int l, int r) {
    if (!a || !b)
        return a | b;
    
    if (l == r) {
        cnt[a] += cnt[b], ans[a] = l;
        return a;
    }
    
    int mid = (l + r) >> 1;
    lc[a] = merge(lc[a], lc[b], l, mid);
    rc[a] = merge(rc[a], rc[b], mid + 1, r);
    return pushup(a), a;
}
} // namespace SMT

void dfs(int u) {
    for (auto v : G.e[u])
        if (dep[v] > dep[u])
            dfs(v), SMT::rt[u] = SMT::merge(SMT::rt[u], SMT::rt[v], 1, 1e5);
        
    if (SMT::cnt[SMT::rt[u]])
        ans[u] = SMT::ans[SMT::rt[u]];
}

signed main() {
    scanf("%d%d", &n, &m);
    
    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G.insert(u, v), G.insert(v, u);
    }
    
    dfs(1, 0);
    
    while (m--) {
        int x, y, k;
        scanf("%d%d%d", &x, &y, &k);
        int lca = LCA(x, y);
        SMT::update(SMT::rt[x], 1, 1e5, k, 1);
        SMT::update(SMT::rt[y], 1, 1e5, k, 1);
        SMT::update(SMT::rt[lca], 1, 1e5, k, -1);
        
        if (fa[lca][0])
            SMT::update(SMT::rt[fa[lca][0]], 1, 1e5, k, -1);
    }
    
    dfs(1);
    
    for (int i = 1; i <= n; ++i)
        printf("%d\n", ans[i]);
    
    return 0;
}

线段树分裂

本质是线段树合并的逆过程。线段树分裂只适用于有序的序列,无序的序列是没有意义的,常用在动态开点的权值线段树。

按大小、权值分裂和 fhq-Treap 是类似的:

void split_siz(int a, int &b, int k) {
    if (!a)
        return;
    
    b = newnode();
    
    if (k > s[lc[a]])
        split_siz(rc[a], rc[b], k - s[lc[a]]);
    else {
        rc[b] = rc[a], rc[a] = 0;
        
        if (k < s[lc[a]])
            split_siz(lc[a], lc[b], k);
    }
    
    s[b] = s[a] - k, s[a] = k;
}

从一颗区间为 \([1, n]\) 的线段树中分裂出 \([l, r]\) ,从根开始递归分裂,流程如下:

  • 当节点不存在或者代表的区间 \([s, t]\)\([l, r]\) 不交则直接回溯。
  • \([s, t]\) 包含于 \([l, r]\) 时,将当前结点直接接到新树下面,并断开旧边。
  • \([s, t]\)\([l, r]\) 有交时需要开一个新点。
void split(int &a, int &b, int nl, int nr, int l, int r) {
    if (!a)
        return;

    if (l <= nl && nr <= r) {
        b = a, a = 0;
        return;
    }

    if (!b)
        b = newnode();

    int mid = (nl + nr) >> 1;

    if (l <= mid)
        split(lc[a], lc[b], nl, mid, l, r);

    if (r > mid)
        split(rc[a], rc[b], mid + 1, nr, l, r);

    pushup(a), pushup(b);
}

可以发现被断开的边最多只会有 \(O(\log n)\) 条,所以最终每次分裂的时间复杂度就是 \(O(\log n)\)

P5494 【模板】线段树分裂

给出一个可重集 \(a\)(编号为 \(1\)),它支持以下操作:

  • 0 p x y :将可重集 \(p\)\([x, y]\) 的值移动到一个新的可重集中编号为上一次产生的新可重集的编号 \(+1\)

  • 1 p t :将可重集 \(t\) 中的数放入可重集 \(p\) ,并清空可重集 \(t\)

  • 2 p x q :在可重集 \(p\) 中加入 \(x\) 个数字 \(q\)

  • 3 p x y :查询可重集 \(p\)\([x, y]\) 的值的个数。

  • 4 p k :查询可重集 \(p\) 中第 \(k\) 小的数,不存在时输出 -1

\(n, m \le 2 \times 10^5\)

确实是模板题。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 2e5 + 7;

int n, m;

namespace SMT {
const int S = 3e7 + 7;

ll cnt[S];
int rt[N], lc[S], rc[S];

int tot;

inline void pushup(int x) {
    cnt[x] = cnt[lc[x]] + cnt[rc[x]];
}

void update(int &x, int nl, int nr, int pos, int k) {
    if (!x)
        x = ++tot;

    cnt[x] += k;

    if (nl == nr)
        return;

    int mid = (nl + nr) >> 1;

    if (pos <= mid)
        update(lc[x], nl, mid, pos, k);
    else
        update(rc[x], mid + 1, nr, pos, k);
}

ll query(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return cnt[x];

    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return query(lc[x], nl, mid, l, r);
    else if (l > mid)
        return query(rc[x], mid + 1, nr, l, r);
    else
        return query(lc[x], nl, mid, l, r) + query(rc[x], mid + 1, nr, l, r);
}

int querykth(int x, int nl, int nr, ll k) {
    if (!x)
        return -1;

    if (nl == nr)
        return nl;

    int mid = (nl + nr) >> 1;
    return cnt[lc[x]] >= k ? querykth(lc[x], nl, mid, k) : querykth(rc[x], mid + 1, nr, k - cnt[lc[x]]);
}

int merge(int a, int b) {
    if (!a || !b)
        return a | b;

    cnt[a] += cnt[b];
    lc[a] = merge(lc[a], lc[b]), rc[a] = merge(rc[a], rc[b]);
    return a;
}

void split(int &a, int &b, int nl, int nr, int l, int r) {
    if (!a)
        return;

    if (l <= nl && nr <= r) {
        b = a, a = 0;
        return;
    }

    if (!b)
        b = ++tot;

    int mid = (nl + nr) >> 1;

    if (l <= mid)
        split(lc[a], lc[b], nl, mid, l, r);

    if (r > mid)
        split(rc[a], rc[b], mid + 1, nr, l, r);

    pushup(a), pushup(b);
}
} // namespace SMT

signed main() {
    scanf("%d%d", &n, &m);
    int cnt = 1;

    for (int i = 1; i <= n; ++i) {
        int x;
        scanf("%d", &x);
        SMT::update(SMT::rt[1], 1, n, i, x);
    }

    while (m--) {
        int op;
        scanf("%d", &op);

        if (!op) {
            int x, l, r;
            scanf("%d%d%d", &x, &l, &r);
            SMT::split(SMT::rt[x], SMT::rt[++cnt], 1, n, l, r);
        } else if (op == 1) {
            int x, y;
            scanf("%d%d", &x, &y);
            SMT::merge(SMT::rt[x], SMT::rt[y]);
        } else if (op == 2) {
            int x, c, k;
            scanf("%d%d%d", &x, &c, &k);
            SMT::update(SMT::rt[x], 1, n, k, c);
        } else if (op == 3) {
            int x, l, r;
            scanf("%d%d%d", &x, &l, &r);
            printf("%lld\n", SMT::query(SMT::rt[x], 1, n, l, r));
        } else if (op == 4) {
            int x;
            ll k;
            scanf("%d%lld", &x, &k);
            printf("%d\n", SMT::querykth(SMT::rt[x], 1, n, k));
        }
    }

    return 0;
}

李超线段树

李超线段树支持插入线段/直线,在线查询某个横坐标处的最值。

核心思想是利用标记永久化,在线段树的每个点维护一条优势线段作为该区间的标记,即在 \(mid\) 处取得最值的线段。

普通操作

插入

若该区间无标记,直接打上用该线段更新的标记即可。

否则由于标记难以合并,只能把标记下传。但是子节点也有标记,可能产生冲突。

考虑递归下传标记,按新线段 \(f\) 取值是否大于原标记 \(g\) 可以把当前区间分为两个子区间。在两条线段中,肯定有一条线段,只可能成为左/右区间的答案。具体地,将线段 \(f, g\) 在中点处的值比较。下面只考虑 \(f\) 劣于 \(g\) 情况, \(f\) 更优时交换 \(f, g\) 即可:

  • 若左端点处 \(f\) 优,则 \(f, g\) 必在左区间产生交点,\(f\) 在左区间的一个前缀会优于 \(g\) ,递归左儿子下传标记。
  • 若右端点处 \(f\) 优,则 \(f, g\) 必在右区间产生交点,\(f\) 在右区间的一个后缀会优于 \(g\) ,递归右儿子下传标记。
  • 否则 \(f\) 不可能比 \(g\) 优,无需继续下传。

显然只会下传一边,单次修改时间复杂度 \(O(\log^2 V)\)

void maintain(int x, int l, int r, Line k) {
    int mid = (l + r) >> 1;

    if (cmp(k, s[x], mid))
        swap(s[x], k);

    if (l == r)
        return;

    if (cmp(k, s[x], l))
        maintain(ls(x), l, mid, k);

    if (cmp(k, s[x], r))
        maintain(rs(x), mid + 1, r, k);
}

void update(int x, int nl, int nr, int l, int r, Line k) {
    if (l <= nl && nr <= r) {
        maintain(x, nl, nr, k);
        return;
    }

    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);
}

查询

查询就比较所有包含 \(x\) 的区间的线段得出最终答案,时间复杂度 \(O(\log V)\)

Line query(int x, int nl, int nr, int pos) {
    if (nl == nr)
        return s[x];

    int mid = (nl + nr) >> 1;

    auto pmax = [](Line l1, Line l2, double x) {
        return cmp(l1, l2, x) ? l1 : l2;
    };

    return pmax(s[x], pos <= mid ? query(ls(x), nl, mid, pos) : query(rs(x), mid + 1, nr, pos), pos);
}

扩展

李超线段树支持持久化,支持横坐标范围过大时的动态开点(空间线性),不支持删除(但可以考虑线段树分治)。

李超树合并

类似于普通线段树合并,采用以下过程合并两个李超线段树节点 \(a, b\)\(a\)

  • \(b\) 为空,结束过程。
  • \(a\) 为空,将 \(b\) 复制给 \(a\)
  • \(b\) 对应线段插入 \(a\) 为根的子树。
  • 递归合并左右子树。

若合并若干李超线段树涉及的总点数为 \(n\) ,则总复杂度为 \(O(n \log n)\)

对于任意线段在树上对应点,每次涉及移动它时要么使其深度加一,要么直接删除,这两个操作的代价都是 \(O(1)\) 的,故总复杂度为 \(O(n \log n)\)

int merge(int a, int b, int l, int r) {
    if (!a || !b)
        return a | b;

    int mid = (l + r) >> 1;
    maintain(a, l, r, s[b]);
    lc[a] = merge(lc[a], lc[b], l, mid);
    rc[a] = merge(rc[a], rc[b], mid + 1, r);
    return a;
}

李超线段树合并支持持久化,需要在合并时新建节点。

应用

P4097 【模板】李超线段树 / [HEOI2013] Segment

要求在平面直角坐标系下维护两个操作:

  • 在平面上加入一条线段。记第 \(i\) 条被插入的线段的标号为 \(i\)

  • 给定一个数 \(k\),询问与直线 \(x = k\) 相交的线段中,交点纵坐标最大的线段的编号。

强制在线,\(1 \le n \le 10^5\)\(1 \le k, x_0, x_1 \le 39989\)\(1 \le y_0, y_1 \le 10^9\)

转化问题为:

  • 加入一个一次函数,定义域为 \([l, r]\)
  • 给定 \(k\) ,求定义域包含 \(k\) 的所有一次函数中在 \(x = k\) 处取值最大的,若有多个函数取值相同选编号最小的。

然后就是李超树模板了。

#include <bits/stdc++.h>
using namespace std;
const double eps = 1e-9;
const int Vx = 39989, Vy = 1e9;
const int N = 1e5 + 7;

struct Line {
    double k, b;
    int id;

    inline Line() {}

    inline Line(double x, double y, double xx, double yy, int _id) {
        id = _id;

        if (x == xx)
            b = max(y, yy);
        else
            k = (y - yy) / (x - xx), b = y - k * x;
    }

    inline double operator () (const double x) const {
        return k * x + b;
    }
};

inline int cmp(const double &a, const double &b) {
    if (a - b > eps)
        return 1;
    else if (b - a > eps)
        return -1;
    else
        return 0;
}

inline bool cmp(Line l1, Line l2, double x) {
    int flag = cmp(l1(x), l2(x));
    return flag == 1 || (!flag && l1.id < l2.id);
}

namespace SMT {
Line s[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

void maintain(int x, int l, int r, Line k) {
    int mid = (l + r) >> 1;

    if (cmp(k, s[x], mid))
        swap(s[x], k);

    if (l == r)
        return;

    if (cmp(k, s[x], l))
        maintain(ls(x), l, mid, k);

    if (cmp(k, s[x], r))
        maintain(rs(x), mid + 1, r, k);
}

void update(int x, int nl, int nr, int l, int r, Line k) {
    if (l <= nl && nr <= r) {
        maintain(x, nl, nr, k);
        return;
    }

    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);
}

Line query(int x, int nl, int nr, int p) {
    if (nl == nr)
        return s[x];

    int mid = (nl + nr) >> 1;

    auto pmax = [](Line l1, Line l2, int x) {
        return cmp(l1, l2, x) ? l1 : l2;
    };

    return pmax(s[x], p <= mid ? query(ls(x), nl, mid, p) : query(rs(x), mid + 1, nr, p), p);
}
} // namespace SMT

signed main() {
    int n, cnt = 0, lstans = 0;
    scanf("%d", &n);

    while (n--) {
        int op;
        scanf("%d", &op);

        if (op) {
            int x, y, xx, yy;
            scanf("%d%d%d%d", &x, &y, &xx, &yy);
            x = (x + lstans - 1) % Vx + 1, y = (y + lstans - 1) % Vy + 1;
            xx = (xx + lstans - 1) % Vx + 1, yy = (yy + lstans - 1) % Vy + 1;

            if (x > xx)
                swap(x, xx), swap(y, yy);

            SMT::update(1, 1, Vx, x, xx, Line(x, y, xx, yy, ++cnt));
        } else {
            int x;
            scanf("%d", &x);
            x = (x + lstans - 1) % Vx + 1;
            printf("%d\n", lstans = SMT::query(1, 1, Vx, x).id);
        }
    }

    return 0;
}

P4069 [SDOI2016] 游戏

给出一棵树,初始每个数都是 \(123456789123456789\)\(m\) 次操作:

  • 1 s t a b :将 \(s\)\(t\) 路径上的每个点权值都与 \(dis(s,u)\times a+b\)\(\min\)
  • 2 s t :求 \(s\)\(t\) 路径上的所有点的权值的最小值。

\(n, m \le 10^5\)

树上路径修改考虑树链剖分。记 \(p = \mathrm{LCA}(s, t)\) 修改时分类讨论:

  • \(u \in path(s \to p)\) :此时 \(u\) 的权值为 \(-a \times dis_u + a \times dis_s + b\)
  • \(u \in path(p \to t)\) :此时 \(u\) 的权值为 \(a \times dis_u + a \times (dis_s - 2 dis_p) + b\)

由于一条链上修改时可以保证 \(x\) 坐标( \(dis_u\) )的单调性,所以可以用李超线段树解决。

对于求区间最小值问题,注意到一条线段的最小值显然在端点处取得,判断一下大小取较小者即可。

时间复杂度 \(O(n \log^3 n)\) ,由于树剖和李超树常数都挺小所以效率不错。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 123456789123456789;
const int N = 1e5 + 7;

struct Graph {
    vector<pair<int, int> > e[N];
    
    inline void insert(int u, int v, int w) {
        e[u].emplace_back(v, w);
    }
} G;

ll dis[N];
int fa[N], dep[N], siz[N], son[N], top[N], dfn[N], id[N];

int n, m, dfstime;

void dfs1(int u, int f) {
    fa[u] = f, dep[u] = dep[f] + 1, siz[u] = 1;

    for (auto it : G.e[u]) {
        int v = it.first, w = it.second;

        if (v == f)
            continue;

        dis[v] = dis[u] + w, dfs1(v, u), siz[u] += siz[v];

        if (siz[v] > siz[son[u]])
            son[u] = v;
    }
}

void dfs2(int u, int topf) {
    top[u] = topf, id[dfn[u] = ++dfstime] = u;

    if (son[u])
        dfs2(son[u], topf);

    for (auto it : G.e[u]) {
        int v = it.first;

        if (v != fa[u] && v != son[u])
            dfs2(v, v);
    }
}

inline int LCA(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]])
            swap(x, y);

        x = fa[top[x]];
    }

    return dep[x] < dep[y] ? x : y;
}

struct Line {
    ll k, b;

    inline Line() {}

    inline Line(ll _k, ll _b) : k(_k), b(_b) {}

    inline ll operator () (const int x) const {
        return k * dis[id[x]] + b;
    }
};

namespace SMT {
Line s[N << 2];

ll mn[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void pushup(int x, int l, int r) {
    mn[x] = min(min(mn[ls(x)], mn[rs(x)]), min(s[x](l), s[x](r)));
}

void build(int x, int l, int r) {
    s[x] = Line(0, inf), mn[x] = inf;

    if (l == r)
        return;

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
}

void maintain(int x, int l, int r, Line k) {
    int mid = (l + r) >> 1;

    if (k(mid) < s[x](mid))
        swap(k, s[x]);

    if (l == r) {
        mn[x] = s[x](l);
        return;
    }

    if (k(l) < s[x](l))
        maintain(ls(x), l, mid, k);

    if (k(r) < s[x](r))
        maintain(rs(x), mid + 1, r, k);

    pushup(x, l, r);
}

void update(int x, int nl, int nr, int l, int r, Line k) {
    if (l <= nl && nr <= r) {
        maintain(x, nl, nr, k);
        return;
    }

    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);

    pushup(x, nl, nr);
}

ll query(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return mn[x];

    int mid = (nl + nr) >> 1;
    ll res = min(s[x](max(nl, l)), s[x](min(nr, r)));

    if (l <= mid)
        res = min(res, query(ls(x), nl, mid, l, r));
    
    if (r > mid)
        res = min(res, query(rs(x), mid + 1, nr, l, r));

    return res;
}
} // namespace SMT

inline void update(int x, int y, Line k) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]])
            swap(x, y);

        SMT::update(1, 1, n, dfn[top[x]], dfn[x], k);
        x = fa[top[x]];
    }

    if (dep[x] > dep[y])
        swap(x, y);

    SMT::update(1, 1, n, dfn[x], dfn[y], k);
}

inline ll query(int x, int y) {
    ll res = inf;

    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]])
            swap(x, y);

        res = min(res, SMT::query(1, 1, n, dfn[top[x]], dfn[x]));
        x = fa[top[x]];
    }

    if (dep[x] > dep[y])
        swap(x, y);

    res = min(res, SMT::query(1, 1, n, dfn[x], dfn[y]));
    return res;
}

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i < n; ++i) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        G.insert(u, v, w), G.insert(v, u, w);
    }

    dfs1(1, 0), dfs2(1, 1);
    SMT::build(1, 1, n);

    while (m--) {
        int op, x, y;
        scanf("%d%d%d", &op, &x, &y);

        if (op == 1) {
            int a, b, lca = LCA(x, y);
            scanf("%d%d", &a, &b);
            update(x, lca, Line(-a, 1ll * a * dis[x] + b));
            update(lca, y, Line(a, 1ll * a * (dis[x] - dis[lca] * 2) + b));
        } else
            printf("%lld\n", query(x, y));
    }

    return 0;
}

CF932F Escape Through Leaf

有一棵 \(n\) 个点的树,根为 \(1\) 。第 \(i\) 个节点有两个权值 \(a_i, b_i\)

你可以从一个节点跳到它的子树内任意一个节点上。从节点 \(x\) 跳到节点 \(y\) 一次的花费为 \(a_x \times b_y\)。跳跃多次走过一条路径的总费用为每次跳跃的费用之和。分别计算每个点到达子树的每个叶子节点的费用中的最小值。

\(n \le 10^5\)

一个显然的 DP 是设 \(f_u\)\(u\) 的答案,则 \(f_u = \min_{v \in subtree(u)} f_v + a_u \times b_v\) 。这个 DP 可以通过插入子树内所有线段做到,用李超树合并维护可行集合即可。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 1e18;
const int N = 1e5 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

struct Line {
    ll k, b;

    inline Line() {}

    inline Line(ll _k, ll _b) : k(_k), b(_b) {}

    inline ll operator () (const ll &x) const {
        return k * x + b;
    }
};

ll ans[N];
int a[N], b[N];

int n;

namespace SMT {
const int S = N << 5;

Line s[S];

int lc[S], rc[S];
int rt[N];

int tot;

void maintain(int &x, int l, int r, Line k) {
    if (!x) {
        s[x = ++tot] = k;
        return;
    }

    int mid = (l + r) >> 1;

    if (k(mid) < s[x](mid))
        swap(s[x], k);

    if (l == r)
        return;

    if (k(l) < s[x](l))
        maintain(lc[x], l, mid, k);

    if (k(r) < s[x](r))
        maintain(rc[x], mid + 1, r, k);
}

ll query(int x, int nl, int nr, int pos) {
    if (!x)
        return inf;

    if (nl == nr)
        return s[x](pos);

    int mid = (nl + nr) >> 1;
    return min(s[x](pos), pos <= mid ? query(lc[x], nl, mid, pos) : query(rc[x], mid + 1, nr, pos));
}

int merge(int a, int b, int l, int r) {
    if (!a || !b)
        return a | b;

    int mid = (l + r) >> 1;
    maintain(a, l, r, s[b]);
    lc[a] = merge(lc[a], lc[b], l, mid);
    rc[a] = merge(rc[a], rc[b], mid + 1, r);
    return a;
}
} // namespace SMT

void dfs(int u, int f) {
    int sonsum = 0;

    for (int v : G.e[u])
        if (v != f)
            dfs(v, u), ++sonsum, SMT::rt[u] = SMT::merge(SMT::rt[u], SMT::rt[v], -N, N);

    ans[u] = sonsum ? SMT::query(SMT::rt[u], -N, N, a[u]) : 0;
    SMT::maintain(SMT::rt[u], -N, N, Line(b[u], ans[u]));
}

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    for (int i = 1; i <= n; ++i)
        scanf("%d", b + i);

    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G.insert(u, v), G.insert(v, u);
    }

    dfs(1, 0);

    for (int i = 1; i <= n; ++i)
        printf("%lld ", ans[i]);

    return 0;
}

P5508 寻宝

给定 \(n\) 个点,其中有两类边:

  • 一类边: \(m\) 条区间连区间的边。
  • 二类边:对于任意 \(a_i \ne 0\)\(i\) ,对于其他所有 \(j\) 满足有一条 \(i \to j\) 的边权为 \(|i - j| \times a_i\) 的边。

\(1 \to n\) 的最短路及方案。

\(n \le 5 \times 10^4\)

对于区间连区间的边,直接线段树优化建图即可。

回想 Dijkstra 的过程:找到没有更新过其他点的距离最小的点,用这个点更新所有能够到达的点。

发现每个点都可以通过二类边到达其他所有点,考虑不显示建出二类边,而只考虑二类边造成的影响。

设当前的点为 \(u\) ,不难发现,如果以连向的点的编号为横坐标,距离为纵坐标,那么通过 \(u\) 连向 \(v\) 的二类边对 \(v\) 更新的距离在坐标系上形成两条线段:

  • 左侧 \(x \in [1, u)\)\(dis_x = y = -a_u x + (dis_u + u \times a_u)\)
  • 右侧 \(x \in (u, n]\)\(dis_x = y = a_u x + (dis_u - u \times a_u)\)

因此,二类边实际上是对于对于每个点将其当前距离与一条线段取 \(\min\) ,可以用李超线段树维护。

考虑一棵拥有堆的功能的李超线段树,也就是支持单点删除,查询全局最小值及位置,区间对线段取 \(\min\)

Dijkstra 流程中维护两个堆,一个是普通的一类堆,一个是李超树维护的二类堆,每次比较堆顶选 \(dis\) 较小者松弛其他点即可。

时间复杂度 \(O(n \log^2 n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const int N = 5e4 + 7, M = N * 11;

struct Graph {
    vector<pair<int, int> > e[M];
    
    inline void insert(int u, int v, int w) {
        e[u].emplace_back(v, w);
    }
} G;

struct Line {
    ll k, b;
    int id;

    inline ll operator () (const int x) const {
        return k * x + b;
    }
};

ll dis[M];
int a[N], pre[M];
bool vis[M];

int n, m;

namespace SGT {
int tot;

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

void build(int x, int l, int r) {
    if (l == r) {
        G.insert(x + n * 5, l, 0), G.insert(l, x + n, 0);
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    G.insert(ls(x) + n, x + n, 0), G.insert(rs(x) + n, x + n, 0);
    G.insert(x + n * 5, ls(x) + n * 5, 0), G.insert(x + n * 5, rs(x) + n * 5, 0);
}

void update(int x, int nl, int nr, int l, int r, int k, int op) {
    if (l <= nl && nr <= r) {
        if (op)
            G.insert(k, x + n * 5, 0);
        else
            G.insert(x + n, k, 0);

        return;
    }

    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k, op);
    
    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k, op);
}

inline void insert(int l1, int r1, int l2, int r2, int w) {
    update(1, 1, n, l1, r1, tot, 0), update(1, 1, n, l2, r2, tot + 1, 1);
    G.insert(tot, tot + 1, w), tot += 2;
}
} // namespace SGT

namespace SMT {
struct Node {
    ll val;
    int x, id;

    inline Node() : val(inf), x(0), id(0) {}

    inline Node(Line l, int _x) : val(l(_x)), x(_x), id(l.id) {}

    inline bool operator < (const Node &rhs) const {
        return val < rhs.val;
    }
} mn[M << 2];

Line s[M << 2];

int L[M << 2], R[M << 2];
bool del[M << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void clear(int x) {
    del[x] = true, mn[x] = Node();
}

inline void pushup(int x) {
    if (del[ls(x)] && del[rs(x)]) {
        clear(x);
        return;
    }

    L[x] = del[ls(x)] ? L[rs(x)] : L[ls(x)];
    R[x] = del[rs(x)] ? R[ls(x)] : R[rs(x)];
    mn[x] = min(min(mn[ls(x)], mn[rs(x)]), min(Node(s[x], L[x]), Node(s[x], R[x])));
}

void build(int x, int l, int r) {
    s[x] = (Line) {0, inf, 0}, mn[x] = Node(), L[x] = l, R[x] = r, del[x] = false;

    if (l == r)
        return;

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
}

void maintain(int x, int l, int r, Line k) {
    if (del[x])
        return;

    if (k(L[x]) > s[x](L[x]) && k(R[x]) > s[x](R[x]))
        return;

    if (k(L[x]) < s[x](L[x]) && k(R[x]) < s[x](R[x])) {
        s[x] = k;

        if (l == r)
            mn[x] = Node(s[x], l);
        else
            pushup(x);

        return;
    }

    int mid = (l + r) >> 1;

    if (k(mid) < s[x](mid))
        swap(s[x], k);

    if (l == r)
        return;

    if (k(L[x]) < s[x](L[x]))
        maintain(ls(x), l, mid, k);

    if (k(R[x]) < s[x](R[x]))
        maintain(rs(x), mid + 1, r, k);

    pushup(x);
}

void update(int x, int nl, int nr, int l, int r, Line k) {
    if (del[x])
        return;

    if (l <= nl && nr <= r) {
        maintain(x, nl, nr, k);
        return;
    }

    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);

    pushup(x);
}

void remove(int x, int l, int r, int pos) {
    if (l == r) {
        clear(x);
        return;
    }

    int mid = (l + r) >> 1;

    if (pos <= mid)
        remove(ls(x), l, mid, pos);
    else
        remove(rs(x), mid + 1, r, pos);

    pushup(x);
}
} // namespace SMT

inline void Dijkstra(int S) {
    memset(dis + 1, inf, sizeof(ll) * (n * 9 + m * 2));
    priority_queue<pair<ll, int> > q;
    SMT::build(1, 1, n);
    dis[S] = 0, q.emplace(-dis[S], S);

    while (!q.empty() || SMT::mn[1].val != inf) {
        int u;

        if (q.empty() || SMT::mn[1].val < -q.top().first) {
            u = SMT::mn[1].x;

            if (vis[u]) {
                SMT::remove(1, 1, n, u);
                continue;
            }

            dis[u] = SMT::mn[1].val, pre[u] = SMT::mn[1].id;
            SMT::remove(1, 1, n, u);
        } else {
            u = q.top().second, q.pop();

            if (vis[u])
                continue;
        }

        vis[u] = true;

        for (auto it : G.e[u]) {
            int v = it.first, w = it.second;

            if (dis[v] > dis[u] + w)
                dis[v] = dis[u] + w, pre[v] = u, q.emplace(-dis[v], v);
        }

        if (u <= n && a[u]) {
            if (u > 1)
                SMT::update(1, 1, n, 1, u - 1, (Line) {-a[u], dis[u] + 1ll * u * a[u], u});

            if (u < n)
                SMT::update(1, 1, n, u + 1, n, (Line) {a[u], dis[u] - 1ll * u * a[u], u});
        }
    }
}

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    SGT::build(1, 1, n), SGT::tot = n * 9 + 1;

    for (int i = 1; i <= m; ++i) {
        int l1, r1, l2, r2, w;
        scanf("%d%d%d%d%d", &l1, &r1, &l2, &r2, &w);
        SGT::insert(l1, r1, l2, r2, w);
    }

    Dijkstra(1);

    if (dis[n] == inf)
        return puts("-1"), 0;

    printf("%lld\n", dis[n]);
    stack<int> path;

    for (int cur = n; cur; cur = pre[cur])
        if (cur <= n)
            path.emplace(cur);

    printf("%d\n", path.size());

    while (!path.empty())
        printf("%d ", path.top()), path.pop();

    return 0;
}

CF1175G Yet Another Partiton Problem

给定 \(a_{1 \sim n}\) ,需要将其划分为 \(k\) 段,每一段的权值为区间长度乘上区间最大值,最小化每段权值和。

\(n \le 2 \times 10^4\)\(k \le 100\)

\(f_{i, j}\) 表示前 \(i\) 个元素划分为 \(j\) 段的答案,则:

\[f_{i, j} = \min_{k \le i} \{ f_{k - 1, j - 1} + (i - k + 1) \times \max_{l = k}^i a_l \} \]

先用单调栈处理后缀 \(\max\) ,这样所有 \(k\) 就被划分为若干段,每段的后缀 \(\max\) 相等。

那么对于每一段,需要维护 \(f_{k, j - 1} - k \times \max_{l = k}^i a_l\) 的最小值,这显然可以维护 \((k, f_{k - 1, j - 1})\) 的凸包,在凸包上二分解决。

由于单调栈的过程会合并若干凸包,使用启发式合并即可。

考虑每段的贡献,其形如 \((i + 1) \times (\max_{l = k}^i a_l) + (\min f_{k, j - 1} - k \times \max_{l = k}^i a_l)\) ,考虑用李超树维护。但是李超树不支持删除,发现单调栈合并都是从末尾开始合并,因此对李超树可持久化即可。

时间复杂度 \(O(nk \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 2e4 + 7;

struct Point {
    int x, y;

    inline friend double slope(Point a, Point b) {
        return (double)(a.y - b.y) / (a.x - b.x);
    }
};

struct ConvexHull {
    deque<Point> q;

    inline void emplace_back(Point k) {
        while (q.size() >= 2 && slope(q[q.size() - 2], q.back()) >= slope(q.back(), k))
            q.pop_back();

        q.emplace_back(k);
    }

    inline void emplace_front(Point k) {
        while (q.size() >= 2 && slope(k, q[0]) >= slope(q[0], q[1]))
            q.pop_front();

        q.emplace_front(k);
    }

    inline void merge(ConvexHull &rhs) {
        if (q.size() > rhs.q.size()) {
            for (int i = rhs.q.size() - 1; ~i; --i)
                emplace_front(rhs.q[i]);
        } else {
            for (Point it : q)
                rhs.emplace_back(it);

            swap(q, rhs.q);
        }
    }

    inline int query(int k) {
        int l = 0, r = q.size() - 2, p = q.size() - 1;

        while (l <= r) {
            int mid = (l + r) >> 1;

            if (slope(q[mid], q[mid + 1]) >= k)
                p = mid, r = mid - 1;
            else
                l = mid + 1;
        }

        return q[p].y - k * q[p].x;
    }
} cvh[N];

int a[N], f[N], g[N], sta[N];

int n, m;

struct Line {
    int k, b;

    inline Line() : k(0), b(inf) {}

    inline Line(const int _k, const int _b) : k(_k), b(_b) {}

    inline int operator () (const int &x) const {
        return k * x + b;
    }
};

namespace SMT {
const int S = N << 5;

Line li[S];

int lc[S], rc[S], rt[N];

int tot;

int insert(int x, int l, int r, Line k) {
    int y = ++tot;
    lc[y] = lc[x], rc[y] = rc[x], li[y] = li[x];
    int mid = (l + r) >> 1;

    if (k(mid) < li[y](mid))
        swap(li[y], k);

    if (l == r)
        return y;

    if (k(l) < li[y](l))
        lc[y] = insert(lc[x], l, mid, k);

    if (k(r) < li[y](r))
        rc[y] = insert(rc[x], mid + 1, r, k);

    return y;
}

int query(int x, int nl, int nr, int p) {
    if (!x)
        return inf;

    if (nl == nr)
        return li[x](p);

    int mid = (nl + nr) >> 1;
    return min(li[x](p), p <= mid ? query(lc[x], nl, mid, p) : query(rc[x], mid + 1, nr, p));
}
} // namespace SMT

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1, mx = -1; i <= n; ++i)
        scanf("%d", a + i), f[i] = 1ll * i * (mx = max(mx, a[i]));

    for (int i = 2; i <= m; ++i) {
        memcpy(g + 1, f + 1, sizeof(int) * n);
        SMT::tot = 0;

        for (int j = i, top = 0; j <= n; ++j) {
            ConvexHull now;
            now.emplace_back((Point) {j, g[j - 1]});

            while (top && a[sta[top]] <= a[j])
                now.merge(cvh[top--]);

            sta[++top] = j, swap(cvh[top], now);
            SMT::rt[top] = SMT::insert(SMT::rt[top - 1], 2, n + 1, Line(a[j], cvh[top].query(a[j])));
            f[j] = SMT::query(SMT::rt[top], 2, n + 1, j + 1);
        }
    }

    printf("%d", f[n]);
    return 0;
}

兔队线段树

通常用于维护前缀最值相关问题,下面以前缀最大值为例展开叙述。

对于线段树上的每个点 \(x\) 维护:

  • \(mx_x\) :区间最大值。

  • \(ans_x\) :以左子树的区间最大值为初始的前缀最大值时,右子树内所有前缀最大值的信息和。特别地,叶子无定义。

定义函数 \(\mathrm{calc}(x, k)\) 表示以 \(k\) 为初始的前缀最大值时,区间内所有前缀最大值的信息和。

  • \(x\) 为叶节点,则 \(\mathrm{calc}(x, k) = [k < mx_x]\)
  • \(k < mx_{lc_x}\) ,则 \(\mathrm{calc}(x, k) = calc(lc_x, k) + ans_x\)
  • \(k \ge mx_{lc_x}\) ,则 \(\mathrm{calc}(x, k) = \mathrm{calc}(rc_x, k)\)

不难发现调用一次 \(calc\) 函数复杂度为 \(O(\log n)\)

pushup 时有 \(ans_x = \mathrm{calc}(rc_x, mx_{lc_x})\) ,于是单点修改的复杂度为 \(O(\log^2 n)\)

答案即为 \(calc(x, - \infty)\)

P4198 楼房重建

\(n\) 栋楼,第 \(i\) 栋楼可以抽象成一条两端点为 \((i, 0)\)\((i, h_i)\) 的线段。

初始时 \(h_i\) 均为 \(0\),要支持动态修改单点的 \(h_i\)

每次询问从 \((0, 0)\) 点可以看到多少栋楼房。

能看到一栋楼 \(i\) 当且仅当 \(h_i > 0\)\((0, 0)\)\((i, h_i)\) 的连线上不经过其它楼房。

\(n, m \le 10^5\)

\(s_i = \frac{h_i}{i}\) ,即 \((0, 0)\)\((i, h_i)\) 的斜率,特别的定义 \(s_0 = 0\) 。则一栋楼房 \(i\) 能被看见,当且仅当 \(\max_{j = 0}^{i - 1} x_j < x_i\) ,也就是说它是 \(s_i\) 的前缀严格最大值,剩下的就是套板子了。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7;

int n, m;

namespace SMT {
double mx[N << 2];
int ans[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

void build(int x, int l, int r) {
    mx[x] = 0, ans[x] = 1;

    if (l == r)
        return;

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
}

int calc(int x, int l, int r, double k) {
    if (l == r)
        return k < mx[x];

    int mid = (l + r) >> 1;
    return k < mx[ls(x)] ? calc(ls(x), l, mid, k) + ans[x] : calc(rs(x), mid + 1, r, k);
}

inline void pushup(int x, int l, int r) {
    int mid = (l + r) >> 1;
    mx[x] = max(mx[ls(x)], mx[rs(x)]);
    ans[x] = calc(rs(x), mid + 1, r, mx[ls(x)]);
}

void update(int x, int nl, int nr, int pos, double k) {
    if (nl == nr) {
        mx[x] = k;
        return;
    }

    int mid = (nl + nr) >> 1;

    if (pos <= mid)
        update(ls(x), nl, mid, pos, k);
    else
        update(rs(x), mid + 1, nr, pos, k);

    pushup(x, nl, nr);
}
} // namespace SMT

signed main() {
    scanf("%d%d", &n, &m);
    SMT::build(1, 1, n);

    while (m--) {
        int x, k;
        scanf("%d%d", &x, &k);
        SMT::update(1, 1, n, x, (double)k / x);
        printf("%d\n", SMT::calc(1, 1, n, 0));
    }

    return 0;
}

QOJ5098. 第一代图灵机

给定非负整数序列 \(a_{1 \sim n}\) ,每个位置有一个颜色 \(c_i \in [1, m]\)\(q\) 次操作,操作有:

  • 1 l r :询问区间 \([l, r]\) 中没有重复颜色且数字和最大的子区间的数字和。
  • 2 x k :修改 \(c_x\)\(k\)

\(n, m, q \le 2 \times 10^5\)

\(s\)\(a\) 的前缀和,\(lst_i\) 表示上一个与 \(i\) 颜色相同的位置。对于一个询问 \([l, r]\) ,则答案为:

\[\max_{i = l}^r (s_i - s_{\max(l - 1, \max_{j \le i} lst_j)}) \]

\(lst_i\) 的前缀最大值结构容易想到兔队线段树。对线段树上的点 \(x\) 维护:

  • \(mx_x\) 表示区间内 \(lst\) 的最大值。
  • \(ans_x\) 表示以左子树的区间最大值为初始的前缀最大值时,右子树内 \(s_i - s_{\max_{j \le i} lst_j}\) 的最大值。

定义函数 \(\mathrm{calc}(x, l, r, k)\) 表示以线段树上点 \(x\) 表示的区间 \([l, r]\) 内的点为右端点,左端点 \(\ge k\) 的区间和最大值。分类讨论 \(\mathrm{calc}(x, l, r, k)\) 的计算:

  • 叶子节点:\(\mathrm{calc}(x, l, r, k) = s_l - s_{\max(lst_l, k)}\)
  • \(k \le mx_{lc_x}\)\(\mathrm{calc}(x, l, r, k) = \max(\mathrm{calc}(lc_x, l, mid, k), ans_{rc_x})\)
  • \(k > mx_{lc_x}\)\(\mathrm{calc}(x, l, r, k) = \max(s_{mid} - s_k, \mathrm{calc}(rc_x, mid + 1, r, k))\)

由于维护信息的最大值并非在严格前缀最大值处取得,只是套了一个前缀最大值结构,因此 \(k > mx_{lc_x}\) 的部分细节有些差异,时间复杂度 \(O((n + q) \log^2 n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 2e5 + 7;

set<int> st[N];

ll s[N];
int a[N], c[N], lst[N];

int n, m, q;

namespace SMT {
ll ans[N << 2];
int mx[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

ll calc(int x, int l, int r, int k) {
    if (l == r)
        return s[l] - s[max(lst[l], k)];

    int mid = (l + r) >> 1;
    return k <= mx[ls(x)] ? max(calc(ls(x), l, mid, k), ans[x]) : max(s[mid] - s[k], calc(rs(x), mid + 1, r, k));
}

inline void pushup(int x, int l, int r) {
    int mid = (l + r) >> 1;
    mx[x] = max(mx[ls(x)], mx[rs(x)]);
    ans[x] = calc(rs(x), mid + 1, r, mx[ls(x)]);
}

void build(int x, int l, int r) {
    if (l == r) {
        mx[x] = lst[l];
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    pushup(x, l, r);
}

void update(int x, int nl, int nr, int p, int k) {
    if (nl == nr) {
        mx[x] = k;
        return;
    }

    int mid = (nl + nr) >> 1;

    if (p <= mid)
        update(ls(x), nl, mid, p, k);
    else
        update(rs(x), mid + 1, nr, p, k);

    pushup(x, nl, nr);
}

ll query(int x, int nl, int nr, int l, int r, int &k) {
    if (l <= nl && nr <= r) {
        ll res = calc(x, nl, nr, k);
        k = max(k, mx[x]);
        return res;
    }

    ll res = 0;
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        res = max(res, query(ls(x), nl, mid, l, r, k));
    
    if (r > mid)
        res = max(res, query(rs(x), mid + 1, nr, l, r, k));

    return res;
}
} // namespace SMT

signed main() {
    scanf("%d%d%d", &n, &m, &q);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i), s[i] = s[i - 1] + a[i];

    for (int i = 1; i <= n; ++i) {
        scanf("%d", c + i);
        lst[i] = (st[c[i]].empty() ? 0 : *st[c[i]].rbegin());
        st[c[i]].emplace(i);
    }

    SMT::build(1, 1, n);

    while (q--) {
        int op;
        scanf("%d", &op);

        if (op == 1) {
            int l, r;
            scanf("%d%d", &l, &r);
            int lim = l - 1;
            printf("%lld\n", SMT::query(1, 1, n, l, r, lim));
        } else {
            int x, k;
            scanf("%d%d", &x, &k);
            auto it = st[c[x]].find(x);

            if (next(it) != st[c[x]].end())
                SMT::update(1, 1, n, *next(it), lst[*next(it)] = lst[x]);

            st[c[x]].erase(x), it = st[c[x] = k].emplace(x).first;
            SMT::update(1, 1, n, x, lst[x] = (it == st[k].begin() ? 0 : *prev(it)));

            if (next(it) != st[k].end())
                SMT::update(1, 1, n, *next(it), lst[*next(it)] = x);
        }
    }

    return 0;
}

zkw 线段树

常数相较于普通线段树具有显著的优势,但是不能处理带有运算优先级的问题,需要使用标记永久化。

下面以区间加、区间求和的问题(P3372 【模板】线段树 1)为例引入。

考虑先把线段树填充成满二叉树:

inline void build() {
    for (m = 1; m <= n + 1; m <<= 1);
    
    for (int i = 1; i <= n; ++i)
        sum[i + m] = a[i];
    
    for (int i = m; i; --i)
        sum[i] = sum[i << 1] + sum[i << 1 | 1];
}

注:建树的一些看起来不必要的操作是为了便于后面的修改、查询操作。

对于单点操作,直接找到叶子节点,然后一直 pushup 上去即可。

inline void update(int x, int k) {
    for (x += m; x; x >>= 1)
        sum[x] += k;
}

inline int query(int x, int k) {
    int res = 0;
    
    for (x += m; x; x >>= 1)
        res += sum[x];
    
    return res;
}

对于区间操作,考虑分别在区间左端点 \(-1\) 和右端点 \(+1\) 处放两个指针,记为 \(nl, nr\) 。不断令这两个指针一直跳父亲,直到父亲相同为止,过程中:

  • 若左指针是从左儿子跳上来的,那么处理左指针的右子树的贡献。
  • 若右指针是从右儿子跳上来的,那么处理右指针的左子树的贡献。
inline void update(int l, int r, int k) {
    int nl = l + m - 1, nr = r + m + 1, lenl = 0, lenr = 0, len = 1;
    
    while (nl ^ nr ^ 1) {
        sum[nl] += k * lenl, sum[nr] += k * lenr;
        
        if (~nl & 1)
            tag[nl ^ 1] += k, sum[nl ^ 1] += k * len, lenl += len;
        
        if (nr & 1)
            tag[nr ^ 1] += k, sum[nr ^ 1] += k * len, lenr += len;
        
        nl >>= 1, nr >>= 1, len <<= 1;
    }

    for (; nl && nr; nl >>= 1, nr >>= 1)
        sum[nl] += 1ll * k * lenl, sum[nr] += 1ll * k * lenr;
}

inline ll query(int l, int r) {
    ll res = 0;
    int nl = l + m - 1, nr = r + m + 1, lenl = 0, lenr = 0, len = 1;
    
    while (nl ^ nr ^ 1) {
        res += tag[nl] * lenl + tag[nr] * lenr;
        
        if (~nl & 1)
            res += sum[nl ^ 1], lenl += len;
        
        if (nr & 1)
            res += sum[nr ^ 1], lenr += len;
        
        nl >>= 1, nr >>= 1, len <<= 1;
    }

    for (; nl && nr; nl >>= 1, nr >>= 1)
        res += tag[nl] * lenl + tag[nr] * lenr;
    
    return res;
}

势能线段树

势能分析:基于势能来分析线段树上的一种类似剪枝的优化方法的复杂度。

这类题目通常需要构造一个势能函数,发现某些时候区间只能暴力递归修改,而有的时候(通常是势能降为 \(0\) )整个区间可以打标记一起处理。

复杂度分析就是尝试证明势能消长和时间消耗有关,最终通过求和得到复杂度。

注意实现时需要在不能做出有效修改时返回。

常见模型:

  • 区间开方:一个正整数 \(x\) 被开方 \(O(\log \log x)\) 次后会变成 \(1\) ,区间最大值为 \(1\) 时则返回。
  • 区间取模:一个正整数 \(x\)\(p\) 有效取模后至少变小一半,区间最大值 \(< p\) 时则返回。
  • 区间除法:整数一个数 \(p\) 会使极差至少变小一半,\(\lfloor \frac{max}{p} \rfloor = \lfloor \frac{min}{p} \rfloor\) 时则打加法标记。
  • 区间按位与(按位或):一个正整数 \(x\) 被有效按位与 \(O(\log x)\) 次后会变成 \(0\) ,若区间所有有 \(1\) 的位与 \(x\)\(0\) 的位不交时则返回。

P4145 上帝造题的七分钟 2 / 花神游历各国

区间开根(下取整)、区间求和。

\(n, m \le 10^5\)

注意到一个数 \(x\) 开根 \(O(\log \log x)\) 次后变为 \(1\) ,之后 \(0, 1\) 开根则不变。

考虑对于线段树上的一个点,若这个点表示的区间只有 \(0\)\(1\) ,则开根操作就没用了。

若到达叶节点,必然造成一次开根。总共的开根次数是 \(O(n \log \log V)\)

判断区间是否只有 \(0\)\(1\) 只要判最大值即可,总时间复杂度 \(O(n \log n \log \log V)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7;

ll a[N];

int n, m;

namespace SMT {
ll s[N << 2], mx[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void pushup(int x) {
    s[x] = s[ls(x)] + s[rs(x)];
    mx[x] = max(mx[ls(x)], mx[rs(x)]);
}

void build(int x, int l, int r) {
    if (l == r) {
        s[x] = mx[x] = a[l];
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    pushup(x);
}

void modify(int x, int l, int r) {
    if (mx[x] <= 1)
        return;

    if (l == r) {
        s[x] = mx[x] = sqrt(mx[x]);
        return;
    }

    int mid = (l + r) >> 1;
    modify(ls(x), l, mid), modify(rs(x), mid + 1, r);
    pushup(x);
}

void update(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r) {
        modify(x, nl, nr);
        return;
    }

    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r);

    pushup(x);
}

ll query(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return s[x];

    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return query(ls(x), nl, mid, l, r);
    else if (l > mid)
        return query(rs(x), mid + 1, nr, l, r);
    else
        return query(ls(x), nl, mid, l, r) + query(rs(x), mid + 1, nr, l, r);
}
} // namespace SMT

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%lld", a + i);

    SMT::build(1, 1, n);
    scanf("%d", &m);

    while (m--) {
        int op, l, r;
        scanf("%d%d%d", &op, &l, &r);

        if (l > r)
            swap(l, r);

        if (op)
            printf("%lld\n", SMT::query(1, 1, n, l, r));
        else
            SMT::update(1, 1, n, l, r);
    }

    return 0;
}

LOJ6029. 「雅礼集训 2017 Day1」市场

区间加+区间整除+查询区间最小值+查询区间和。

\(n, m \le 10^5\)

定义线段树上一个点的势能为区间内的极差。

一次整除 \(d\) 操作显然可以至少让极差除以 \(d\) ,所以对于一个点只需要 \(O(\log V)\) 次操作就可以将势能变为 \(0\)

任何整体操作不会影响势能大小,而一次非整体的任意修改操作(如 pushup )会对该节点的势能产生无法预料的变化。一次区间操作最多产生 \(O(\log n)\) 次非整体修改操作,进而带来 \(O(\log n \log V)\) 的额外代价。

总时间复杂度 \(O(n \log n + m \log n \log V)\)

实现时有一些细节,如极差为 \(1\) 是可能出现区间整除操作不改变极差的情况,写法上只要写成最大值和最小值变化量相等时就改为加法标记即可。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7;

int a[N];

int n, m;

namespace SMT {
ll s[N << 2], mn[N << 2], mx[N << 2], tag[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void pushup(int x) {
    s[x] = s[ls(x)] + s[rs(x)];
    mn[x] = min(mn[ls(x)], mn[rs(x)]);
    mx[x] = max(mx[ls(x)], mx[rs(x)]);
}

inline void spread(int x, int l, int r, ll k) {
    s[x] += k * (r - l + 1), mn[x] += k, mx[x] += k, tag[x] += k;
}

inline void pushdown(int x, int l, int r) {
    int mid = (l + r) >> 1;

    if (tag[x])
        spread(ls(x), l, mid, tag[x]), spread(rs(x), mid + 1, r, tag[x]), tag[x] = 0;
}

void build(int x, int l, int r) {
    if (l == r) {
        s[x] = mn[x] = mx[x] = a[l];
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    pushup(x);
}

void update(int x, int nl, int nr, int l, int r, int k) {
    if (l <= nl && nr <= r) {
        spread(x, nl, nr, k);
        return;
    }

    pushdown(x, nl, nr);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);

    pushup(x);
}

void modify(int x, int nl, int nr, int l, int r, int k) {
    if (l <= nl && nr <= r) {
        auto calc = [](ll x, ll k) {
            return x >= 0 ? x / k - x : (x + 1) / k - 1 - x;
        };

        if (calc(mn[x], k) == calc(mx[x], k)) {
            spread(x, nl, nr, calc(mn[x], k));
            return;
        }
    }

    pushdown(x, nl, nr);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        modify(ls(x), nl, mid, l, r, k);

    if (r > mid)
        modify(rs(x), mid + 1, nr, l, r, k);

    pushup(x);
}

ll querysum(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return s[x];

    pushdown(x, nl, nr);
    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return querysum(ls(x), nl, mid, l, r);
    else if (l > mid)
        return querysum(rs(x), mid + 1, nr, l, r);
    else
        return querysum(ls(x), nl, mid, l, r) + querysum(rs(x), mid + 1, nr, l, r);
}

ll querymin(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return mn[x];

    pushdown(x, nl, nr);
    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return querymin(ls(x), nl, mid, l, r);
    else if (l > mid)
        return querymin(rs(x), mid + 1, nr, l, r);
    else
        return min(querymin(ls(x), nl, mid, l, r), querymin(rs(x), mid + 1, nr, l, r));
}
} // namespace SMT

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    SMT::build(1, 1, n);

    while (m--) {
        int op;
        scanf("%d", &op);

        if (op == 1) {
            int l, r, k;
            scanf("%d%d%d", &l, &r, &k);
            SMT::update(1, 1, n, l + 1, r + 1, k);
        } else if (op == 2) {
            int l, r, k;
            scanf("%d%d%d", &l, &r, &k);
            SMT::modify(1, 1, n, l + 1, r + 1, k);
        } else if (op == 3) {
            int l, r;
            scanf("%d%d", &l, &r);
            printf("%lld\n", SMT::querymin(1, 1, n, l + 1, r + 1));
        } else {
            int l, r;
            scanf("%d%d", &l, &r);
            printf("%lld\n", SMT::querysum(1, 1, n, l + 1, r + 1));
        }
    }

    return 0;
}

CF679E Bear and Bad Powers of 42

给出 \(a_{1 \sim n}\)\(m\) 次操作:

  • 1 x :查询 \(a_x\)
  • 2 l r k :将 \(a_{l \sim r}\) 赋值为 \(k\) ,保证 \(k\) 不为 \(42\) 的次幂。
  • 3 l r k :不断将 \(a_{l \sim r}\) 加上 \(k\) 直到 \(a_{l \sim r}\) 中不存在 \(42\) 的次幂为止。

\(n, m \le 10^5\)\(a_i, k \le 10^9\)

显然操作三并不能不停操作下去,因为操作有限次之后一定能找到两个相邻的 \(42\) 次幂使得 \(a_{l \sim r}\) 都夹在中间。事实上只要预处理到 \(42^{11}\) 就足够,对于三操作直接暴力修改即可,下面考虑如何判断区间内是否存在 \(42\) 的次幂。

考虑在线段树上维护区间内与下一个 \(42\) 次幂差值的最小值,则操作二可以暴力赋值,操作三在最小值 \(\ge x\) 或存在赋值标记时打标记,否则保留递归。

定义势能函数为区间所有颜色段上方 \(42\) 次幂的个数和,那么一次二操作会增加 \(O(\log n \log_{42} V)\) 的势能,一次三操作会减少 \(O(1)\) 的势能,总时间复杂度 \(O(n \log n \log_{42} V)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7, L = 12;

ll pw[L];
int a[N];

int n, m;

namespace SMT {
ll mn[N << 2], gap[N << 2], pid[N << 2], tag[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void pushup(int x) {
    mn[x] = min(mn[ls(x)], mn[rs(x)]);
}

inline void spread(int x, ll gk, ll ik, ll k) {
    if (ik)
        mn[x] = gap[x] = gk, pid[x] = ik, tag[x] = 0;
    
    if (k) {
        if (!pid[x])
            mn[x] -= k, tag[x] += k;
        else {
            gap[x] -= k;

            while (gap[x] < 0)
                gap[x] += pw[pid[x] + 1] - pw[pid[x]], ++pid[x];

            mn[x] = gap[x];
        }
    }
}

inline void pushdown(int x) {
    spread(ls(x), gap[x], pid[x], tag[x]);
    spread(rs(x), gap[x], pid[x], tag[x]);
    gap[x] = pid[x] = tag[x] = 0;
}

void build(int x, int l, int r) {
    if (l == r) {
        pid[x] = lower_bound(pw, pw + L, a[l]) - pw;
        mn[x] = gap[x] = pw[pid[x]] - a[l];
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    pushup(x);
}

void cover(int x, int nl, int nr, int l, int r, ll gk, int ik) {
    if (l <= nl && nr <= r) {
        spread(x, gk, ik, 0);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        cover(ls(x), nl, mid, l, r, gk, ik);

    if (r > mid)
        cover(rs(x), mid + 1, nr, l, r, gk, ik);

    pushup(x);
}

void update(int x, int nl, int nr, int l, int r, int k) {
    if (l <= nl && nr <= r && (mn[x] >= k || pid[x])) {
        spread(x, 0, 0, k);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);

    pushup(x);
}

ll query(int x, int nl, int nr, int p) {
    if (nl == nr)
        return pw[pid[x]] - mn[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;
    return p <= mid ? query(ls(x), nl, mid, p) : query(rs(x), mid + 1, nr, p);
}
} // namespace SMT

signed main() {
    pw[0] = 1;

    for (int i = 1; i < L; ++i)
        pw[i] = pw[i - 1] * 42;

    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    SMT::build(1, 1, n);

    while (m--) {
        int op;
        scanf("%d", &op);

        if (op == 1) {
            int x;
            scanf("%d", &x);
            printf("%lld\n", SMT::query(1, 1, n, x));
        } else if (op == 2) {
            int l, r, k;
            scanf("%d%d%d", &l, &r, &k);
            int id = lower_bound(pw, pw + L, k) - pw;
            SMT::cover(1, 1, n, l, r, pw[id] - k, id);
        } else {
            int l, r, k;
            scanf("%d%d%d", &l, &r, &k);

            do
                SMT::update(1, 1, n, l, r, k);
            while (!SMT::mn[1]);
        }
    }

    return 0;
}

P4891 序列

给出两个序列 \(a_{1 \sim n}\)\(b_{1 \sim n}\) ,记 \(c_{1 \sim n}\)\(a\) 的前缀最大值,\(m\) 次对 \(a\)\(b\) 的单点修改操作(保证越改越大),每次修改后求 \(\prod_{i = 1}^n \min(b_i, c_i) \pmod{10^9 + 7}\)

\(n, q \le 10^5\)

首先将 \(a\) 上的单点操作转化为 \(c\) 上的后缀操作,考虑什么时候区间可以整体处理:

  • \(\min c > \max b\) :对区间答案不会造成影响,直接打上区间覆盖的标记。
  • \(k \le \min b\) :会对区间答案造成影响,打上区间覆盖标记、处理贡献后返回。
  • 其他情况:暴力递归。

定义一个区间的势能为 \(c_i < b_i\) 的个数,则一次递归到叶子的单点修改 \(c\) 则会使势能 \(-1\) ,单点修改 \(b\) 有可能会使势能 \(+1\) ,因此可以分析得到时间复杂度 \(O((n + q) \log^2 n)\) (多一个 \(\log\) 是快速幂的复杂度)。

实现是需要分开维护取到 \(b\)\(c\) 的贡献。

#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int Mod = 1e9 + 7;
const int N = 1e5 + 7;

int a[N], b[N], c[N];

int n, m;

inline int mi(int c, int b) {
    int res = 1;
    
    for (; b; b >>= 1, c = 1ll * c * c % Mod)
        if (b & 1)
            res = 1ll * res * c % Mod;
    
    return res;
}

namespace SMT {
int mxa[N << 2], tag[N << 2], mulb[N << 2], cntb[N << 2], mnb[N << 2], mulc[N << 2], cntc[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void pushup(int x) {
    mxa[x] = max(mxa[ls(x)], mxa[rs(x)]);

    mulb[x] = 1ll * mulb[ls(x)] * mulb[rs(x)] % Mod;
    cntb[x] = cntb[ls(x)] + cntb[rs(x)];
    mnb[x] = min(mnb[ls(x)], mnb[rs(x)]);

    mulc[x] = 1ll * mulc[ls(x)] * mulc[rs(x)] % Mod;
    cntc[x] = cntc[ls(x)] + cntc[rs(x)];
}

inline void spread(int x, int k) {
    mulc[x] = mi(k, cntc[x]), tag[x] = k;
}

inline void pushdown(int x) {
    if (~tag[x])
        spread(ls(x), tag[x]), spread(rs(x), tag[x]), tag[x] = -1;
}

void build(int x, int l, int r) {
    tag[x] = -1;

    if (l == r) {
        mxa[x] = a[l];

        if (c[l] < b[l]) {
            mulc[x] = c[l], cntc[x] = 1;
            mulb[x] = 1, cntb[x] = 0, mnb[x] = b[l];
        } else {
            mulc[x] = 1, cntc[x] = 0;
            mulb[x] = b[l], cntb[x] = 1, mnb[x] = inf;
        }

        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    pushup(x);
}

void updateA(int x, int nl, int nr, int pos, int k) {
    if (nl == nr) {
        mxa[x] = max(mxa[x], k);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (pos <= mid)
        updateA(ls(x), nl, mid, pos, k);
    else
        updateA(rs(x), mid + 1, nr, pos, k);

    pushup(x);
}

void updateB(int x, int nl, int nr, int pos, int k, int c) {
    if (nl == nr) {
        c = max(c, mxa[x]);

        if (k > c) {
            mulc[x] = c, cntc[x] = 1;
            mulb[x] = 1, cntb[x] = 0, mnb[x] = k;
        } else {
            mulc[x] = 1, cntc[x] = 0;
            mulb[x] = k, cntb[x] = 1, mnb[x] = inf;
        }

        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (pos <= mid)
        updateB(ls(x), nl, mid, pos, k, c);
    else
        updateB(rs(x), mid + 1, nr, pos, k, max(mxa[ls(x)], c));

    pushup(x);
}

void updateC(int x, int nl, int nr, int l, int r, int k) {
    if (l <= nl && nr <= r) {
        if (mnb[x] > k) {
            spread(x, k);
            return;
        }

        if (nl == nr) {
            mulc[x] = 1, cntc[x] = 0;
            mulb[x] = b[nl], cntb[x] = 1, mnb[x] = inf;
            return;
        }
    }

    int mid = (nl + nr) >> 1;

    if (l <= mid)
        updateC(ls(x), nl, mid, l, r, k);

    if (r > mid)
        updateC(rs(x), mid + 1, nr, l, r, k);

    pushup(x);
}

int search(int x, int l, int r, int k, int c) {
    if (l == r)
        return l;

    pushdown(x);
    int mid = (l + r) >> 1;

    if (max(c, mxa[ls(x)]) <= k)
        return search(rs(x), mid + 1, r, k, max(c, mxa[ls(x)]));
    else
        return search(ls(x), l, mid, k, c);
}

int queryC(int x, int nl, int nr, int pos) {
    if (nl == nr)
        return mxa[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (pos <= mid)
        return queryC(ls(x), nl, mid, pos);
    else
        return max(mxa[ls(x)], queryC(rs(x), mid + 1, nr, pos));
}
} // namespace SMT

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i), c[i] = max(c[i - 1], a[i]);

    for (int i = 1; i <= n; ++i)
        scanf("%d", b + i);

    SMT::build(1, 1, n);

    while (m--) {
        int op, x, k;
        scanf("%d%d%d", &op, &x, &k);

        if (op)
            SMT::updateB(1, 1, n, x, b[x] = k, -1);
        else {
            int c = SMT::queryC(1, 1, n, x);
            SMT::updateA(1, 1, n, x, k);

            if (k > c)
                SMT::updateC(1, 1, n, x, SMT::search(1, 1, n, k, -1), k);
        }

        printf("%d\n", 1ll * SMT::mulb[1] * SMT::mulc[1] % Mod);
    }

    return 0;
}

区间最值操作

P10639 BZOJ4695 最佳女选手

给出序列 \(a_{1 \sim n}\)\(m\) 次操作:区间加、区间取 \(\min / \max\) 、求区间和、求区间 \(\min / \max\)

\(n, m \le 5 \times 10^5\)

考虑对线段树上的每个点维护:

  • 区间信息:
    • 最大值、次大值、最大值个数;
    • 最小值、次小值、最小值个数;
    • 区间和。
  • 区间标记:区间加、区间 \(\max\) 、区间 \(\min\) ,钦定区间加标记优先级最高。

需要特殊处理只有一个数或两个数的时侯可能发生数集重合的情况,细节较多,时间复杂度 \(O(m \log^2 n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 5e5 + 7;

int a[N];

int n, m;

namespace SMT {
struct Tag {
    ll ad;
    int mn, mx;

    inline void spread_add(ll k) {
        if (mx != -inf)
            mx += k;

        if (mn != inf)
            mn += k;

        ad += k;
    }

    inline void spread_min(int k) {
        if (mx > k)
            mx = k;

        mn = k;
    }

    inline void spread_max(int k) {
        if (mn < k)
            mn = k;

        mx = k;
    }
} tag[N << 2];

struct Node {
    ll sum;
    int len, mn, mncnt, secmn, mx, mxcnt, secmx;

    inline friend Node operator + (Node a, Node b) {
        Node c;
        c.len = a.len + b.len, c.sum = a.sum + b.sum;

        if (a.mx == b.mx) {
            c.mx = a.mx, c.mxcnt = a.mxcnt + b.mxcnt;
            c.secmx = max(a.secmx, b.secmx);
        } else if (a.mx > b.mx) {
            c.mx = a.mx, c.mxcnt = a.mxcnt;
            c.secmx = max(a.secmx, b.mx);
        } else {
            c.mx = b.mx, c.mxcnt = b.mxcnt;
            c.secmx = max(a.mx, b.secmx);
        }

        if (a.mn == b.mn) {
            c.mn = a.mn, c.mncnt = a.mncnt + b.mncnt;
            c.secmn = min(a.secmn, b.secmn);
        } else if (a.mn < b.mn) {
            c.mn = a.mn, c.mncnt = a.mncnt;
            c.secmn = min(a.secmn, b.mn);
        } else {
            c.mn = b.mn, c.mncnt = b.mncnt;
            c.secmn = min(a.mn, b.secmn);
        }

        return c;
    }

    inline void spread_add(ll k) {
        sum += k * len, mx += k, mn += k;

        if (secmx != -inf)
            secmx += k;

        if (secmn != inf)
            secmn += k;
    }

    inline void spread_min(int k) {
        sum += 1ll * mxcnt * (k - mx);

        if (secmn == mx)
            secmn = k;

        if (mn == mx)
            mn = k;
        
        mx = k;
    }

    inline void spread_max(int k) {
        sum += 1ll * mncnt * (k - mn);

        if (secmx == mn)
            secmx = k;

        if (mx == mn)
            mx = k;

        mn = k;
    }
} nd[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void spread(int x, Tag k) {
    if (k.ad)
        nd[x].spread_add(k.ad), tag[x].spread_add(k.ad);

    if (k.mn < nd[x].mx)
        nd[x].spread_min(k.mn), tag[x].spread_min(k.mn);

    if (k.mx > nd[x].mn)
        nd[x].spread_max(k.mx), tag[x].spread_max(k.mx);
}

inline void pushdown(int x) {
    spread(ls(x), tag[x]), spread(rs(x), tag[x]), tag[x] = (Tag){0, inf, -inf};
}

void build(int x, int l, int r) {
    tag[x] = (Tag){0, inf, -inf};

    if (l == r) {
        nd[x].sum = nd[x].mx = nd[x].mn = a[l];
        nd[x].len = nd[x].mxcnt = nd[x].mncnt = 1;
        nd[x].secmx = -inf, nd[x].secmn = inf;
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    nd[x] = nd[ls(x)] + nd[rs(x)];
}

void updatesum(int x, int nl, int nr, int l, int r, int k) {
    if (l <= nl && nr <= r) {
        nd[x].spread_add(k), tag[x].spread_add(k);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        updatesum(ls(x), nl, mid, l, r, k);

    if (r > mid)
        updatesum(rs(x), mid + 1, nr, l, r, k);

    nd[x] = nd[ls(x)] + nd[rs(x)];
}

void updatemin(int x, int nl, int nr, int l, int r, int k) {
    if (k >= nd[x].mx)
        return;

    if (l <= nl && nr <= r && nd[x].secmx < k) {
        nd[x].spread_min(k), tag[x].spread_min(k);
        return;
    }

    if (nl == nr)
        return;

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        updatemin(ls(x), nl, mid, l, r, k);

    if (r > mid)
        updatemin(rs(x), mid + 1, nr, l, r, k);

    nd[x] = nd[ls(x)] + nd[rs(x)];
}

void updatemax(int x, int nl, int nr, int l, int r, int k) {
    if (k <= nd[x].mn)
        return;

    if (l <= nl && nr <= r && nd[x].secmn > k) {
        nd[x].spread_max(k), tag[x].spread_max(k);
        return;
    }

    if (nl == nr)
        return;

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        updatemax(ls(x), nl, mid, l, r, k);

    if (r > mid)
        updatemax(rs(x), mid + 1, nr, l, r, k);

    nd[x] = nd[ls(x)] + nd[rs(x)];
}

Node query(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return nd[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return query(ls(x), nl, mid, l, r);
    else if (l > mid)
        return query(rs(x), mid + 1, nr, l, r);
    else
        return query(ls(x), nl, mid, l, r) + query(rs(x), mid + 1, nr, l, r);
}
} // namespace SMT

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    SMT::build(1, 1, n);
    scanf("%d", &m);

    while (m--) {
        int op, l, r;
        scanf("%d%d%d", &op, &l, &r);

        if (op == 1) {
            int x;
            scanf("%d", &x);
            SMT::updatesum(1, 1, n, l, r, x);
        } else if (op == 2) {
            int x;
            scanf("%d", &x);
            SMT::updatemax(1, 1, n, l, r, x);
        } else if (op == 3) {
            int x;
            scanf("%d", &x);
            SMT::updatemin(1, 1, n, l, r, x);
        } else if (op == 4)
            printf("%lld\n", SMT::query(1, 1, n, l, r).sum);
        else if (op == 5)
            printf("%d\n", SMT::query(1, 1, n, l, r).mx);
        else
            printf("%d\n", SMT::query(1, 1, n, l, r).mn);
    }

    return 0;
}

UOJ515. 【UR #19】前进四

给出序列 \(a_{1 \sim n}\)\(m\) 次操作,操作有:

  • \(a_x\) 修改为 \(k\)
  • \(a_{x \sim n}\) 的不同后缀最小值个数。

\(n, m \le 10^6\)

直接上兔队线段树可以做到 \(O((n + m) \log^2 n)\) ,难以通过。

考虑离线从后往前对序列扫描线,用线段树维护每个时间点的后缀最小值。假设现在扫到了 \(i\)

  • 修改:其会影响从修改的时刻开始,到该位置下一次修改的时刻为止的时间段,将这个时间段的后缀最小值都对 \(k\)\(\min\)
  • 询问:答案为该询问时刻被取 \(\min\) 的次数,下传标记时顺便维护即可。

时间复杂度 \(O((n + m) \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 1e6 + 7;

vector<pair<int, int> > upd[N];
vector<int> qry[N];

int ans[N];

int n, m, cntq;

namespace SMT {
int mx[N << 2], sec[N << 2], tag[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void pushup(int x) {
    mx[x] = max(mx[ls(x)], mx[rs(x)]);

    if (mx[ls(x)] == mx[rs(x)])
        sec[x] = max(sec[ls(x)], sec[rs(x)]);
    else if (mx[ls(x)] > mx[rs(x)])
        sec[x] = max(sec[ls(x)], mx[rs(x)]);
    else
        sec[x] = max(mx[ls(x)], sec[rs(x)]);
}

inline void spread(int x, int k, int t) {
    mx[x] = k, tag[x] += t;
}

inline void pushdown(int x) {
    if (mx[ls(x)] > mx[x])
        spread(ls(x), mx[x], tag[x]);

    if (mx[rs(x)] > mx[x])
        spread(rs(x), mx[x], tag[x]);

    tag[x] = 0;
}

void build(int x, int l, int r) {
    mx[x] = inf;

    if (l == r)
        return;

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
}

void update(int x, int nl, int nr, int l, int r, int k) {
    if (k >= mx[x])
        return;

    if (l <= nl && nr <= r && k > sec[x]) {
        spread(x, k, 1);
        return;
    }

    if (nl == nr)
        return;

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);

    pushup(x);
}

int query(int x, int nl, int nr, int p) {
    if (nl == nr)
        return tag[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;
    return p <= mid ? query(ls(x), nl, mid, p) : query(rs(x), mid + 1, nr, p);
}
} // namespace SMT

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i) {
        int k;
        scanf("%d", &k);
        upd[i].emplace_back(0, k);
    }

    while (m--) {
        int op, x;
        scanf("%d%d", &op, &x);

        if (op == 1) {
            if (!upd[x].empty() && upd[x].back().first == cntq)
                upd[x].pop_back();

            int k;
            scanf("%d", &k);
            upd[x].emplace_back(cntq, k);
        } else
            qry[x].emplace_back(cntq++);
    }

    SMT::build(1, 0, cntq);

    for (int i = n; i; --i) {
        reverse(upd[i].begin(), upd[i].end());
        int lsttim = cntq;

        for (auto it : upd[i])
            SMT::update(1, 0, cntq, it.first, lsttim, it.second), lsttim = it.first - 1;

        for (int it : qry[i])
            ans[it] = SMT::query(1, 0, cntq, it);
    }

    for (int i = 0; i < cntq; ++i)
        printf("%d\n", ans[i]);

    return 0;
}

CF793F Julia the snail

有一个长为 \(n\) 的杆,上面有 \(m\) 条绳子,每条绳子可以让蜗牛从 \(l_i\) 爬到 \(r_i\)(中途不能离开),保证 \(r_i\) 各不相同。蜗牛也可以自然下落。

\(q\) 次询问从 \(x\) 出发、途中高度 \(\in [x, y]\) 时最高能爬到的位置。

\(n, m, q \le 10^5\)

考虑离线对 \(y\) 扫描线,维护每个 \(x\) 的答案。

\(f(x)\) 表示高度 \(\ge x\) 时的答案,则每次 \(y \to y + 1\) 时会加入若干条右端点为 \(y\) 的绳子 \([l, y]\) 。考虑这些绳子的影响,对于 \(x \in [1, l]\) ,若 \(f(x) \ge l\) ,则令 \(f(x) = y\)

不难用势能线段树维护这个过程,区间维护最大值和次大值,即可做到 \(O((n + m) \log n + q)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7;

vector<pair<int, int> > qry[N];
vector<int> upd[N];

int ans[N];

int n, m, q;

namespace SMT {
int mx[N << 2], sec[N << 2], tag[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void pushup(int x) {
    mx[x] = max(mx[ls(x)], mx[rs(x)]);
    sec[x] = max(sec[ls(x)], sec[rs(x)]);

    if (mx[ls(x)] != mx[x])
        sec[x] = max(sec[x], mx[ls(x)]);

    if (mx[rs(x)] != mx[x])
        sec[x] = max(sec[x], mx[rs(x)]);
}

inline void spread(int x, int k) {
    mx[x] += k, tag[x] += k;
}

inline void pushdown(int x) {
    if (tag[x]) {
        if (mx[ls(x)] + tag[x] == mx[x])
            spread(ls(x), tag[x]);

        if (mx[rs(x)] + tag[x] == mx[x])
            spread(rs(x), tag[x]);

        tag[x] = 0;
    }
}

void build(int x, int l, int r) {
    mx[x] = r;

    if (l == r)
        return;

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
}

void update(int x, int nl, int nr, int l, int r, int lim, int k) {
    if (mx[x] < lim)
        return;
    
    if (l <= nl && nr <= r && sec[x] < lim) {
        spread(x, k - mx[x]);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, lim, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, lim, k);

    pushup(x);
}

int query(int x, int nl, int nr, int p) {
    if (nl == nr)
        return mx[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;
    return p <= mid ? query(ls(x), nl, mid, p) : query(rs(x), mid + 1, nr, p);
}
} // namespace SMT

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= m; ++i) {
        int l, r;
        scanf("%d%d", &l, &r);
        upd[r].emplace_back(l);
    }

    scanf("%d", &q);

    for (int i = 1; i <= q; ++i) {
        int l, r;
        scanf("%d%d", &l, &r);
        qry[r].emplace_back(l, i);
    }

    SMT::build(1, 1, n);

    for (int i = 1; i <= n; ++i) {
        for (int it : upd[i])
            SMT::update(1, 1, n, 1, it, it, i);

        for (auto it : qry[i])
            ans[it.second] = SMT::query(1, 1, n, it.first);
    }

    for (int i = 1; i <= q; ++i)
        printf("%d\n", ans[i]);

    return 0;
}

历史值相关

P6242 【模板】线段树 3

给定 \(a_{1 \sim n}\)\(m\) 次操作:区间加、区间取 \(\min\) 、区间和、区间最大值、区间历史最大值。

\(n, m \le 5 \times 10^5\)

考虑划分数域将区间最值操作转化为区间加减操作,线段树上的每个点需要维护四种标记:

  • 最大值的加减标记 \(tag1\)
  • 最大值历史最大的加减标记 \(histag1\)
  • 非最大值的加减标记 \(tag2\)
  • 非最大值历史最大的加减标记 \(histag2\)

其实 \(histag\) 就是 \(tag\) 下放前最大的 \(tag\)

前两个标记是只修改最大值的,所以下传时要判断当前节点是否包含区间最大值。

势能分析可以得到总时间复杂度为 \(O((n + m) \log^2 n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 5e5 + 7;

int a[N];

int n, m;

namespace SMT {
ll s[N << 2];
int len[N << 2], mx[N << 2], hismx[N << 2], sec[N << 2], cnt[N << 2];
int tag1[N << 2], tag2[N << 2], histag1[N << 2], histag2[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void pushup(int x) {
    s[x] = s[ls(x)] + s[rs(x)];
    mx[x] = max(mx[ls(x)], mx[rs(x)]);
    hismx[x] = max(hismx[ls(x)], hismx[rs(x)]);

    if (mx[ls(x)] == mx[rs(x)])
        cnt[x] = cnt[ls(x)] + cnt[rs(x)], sec[x] = max(sec[ls(x)], sec[rs(x)]);
    else if (mx[ls(x)] > mx[rs(x)])
        cnt[x] = cnt[ls(x)], sec[x] = max(sec[ls(x)], mx[rs(x)]);
    else
        cnt[x] = cnt[rs(x)], sec[x] = max(mx[ls(x)], sec[rs(x)]);
}

inline void spread(int x, int k1, int k2, int hisk1, int hisk2) {
    s[x] += 1ll * k1 * cnt[x] + 1ll * k2 * (len[x] - cnt[x]);
    hismx[x] = max(hismx[x], mx[x] + hisk1), mx[x] += k1;

    if (sec[x] != -inf)
        sec[x] += k2;

    histag1[x] = max(histag1[x], tag1[x] + hisk1), tag1[x] += k1;
    histag2[x] = max(histag2[x], tag2[x] + hisk2), tag2[x] += k2;
}

inline void pushdown(int x) {
    int mxval = max(mx[ls(x)], mx[rs(x)]);

    if (mx[ls(x)] == mxval)
        spread(ls(x), tag1[x], tag2[x], histag1[x], histag2[x]);
    else
        spread(ls(x), tag2[x], tag2[x], histag2[x], histag2[x]);

    if (mx[rs(x)] == mxval)
        spread(rs(x), tag1[x], tag2[x], histag1[x], histag2[x]);
    else
        spread(rs(x), tag2[x], tag2[x], histag2[x], histag2[x]);

    tag1[x] = tag2[x] = histag1[x] = histag2[x] = 0;
}

void build(int x, int l, int r) {
    len[x] = r - l + 1;

    if (l == r) {
        s[x] = mx[x] = hismx[x] = a[l];
        cnt[x] = 1, sec[x] = -inf;
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    pushup(x);
}

void update(int x, int nl, int nr, int l, int r, int k) {
    if (l <= nl && nr <= r) {
        spread(x, k, k, k, k);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);

    pushup(x);
}

void modify(int x, int nl, int nr, int l, int r, int k) {
    if (k >= mx[x])
        return;

    if (l <= nl && nr <= r && k > sec[x]) {
        spread(x, k - mx[x], 0, k - mx[x], 0);
        return;
    }

    if (nl == nr)
        return;

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        modify(ls(x), nl, mid, l, r, k);

    if (r > mid)
        modify(rs(x), mid + 1, nr, l, r, k);

    pushup(x);
}

ll querysum(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return s[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return querysum(ls(x), nl, mid, l, r);
    else if (l > mid)
        return querysum(rs(x), mid + 1, nr, l, r);
    else
        return querysum(ls(x), nl, mid, l, r) + querysum(rs(x), mid + 1, nr, l, r);
}

int querymax(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return mx[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return querymax(ls(x), nl, mid, l, r);
    else if (l > mid)
        return querymax(rs(x), mid + 1, nr, l, r);
    else
        return max(querymax(ls(x), nl, mid, l, r), querymax(rs(x), mid + 1, nr, l, r));
}

int queryhismax(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return hismx[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return queryhismax(ls(x), nl, mid, l, r);
    else if (l > mid)
        return queryhismax(rs(x), mid + 1, nr, l, r);
    else
        return max(queryhismax(ls(x), nl, mid, l, r), queryhismax(rs(x), mid + 1, nr, l, r));
}
} // namespace SMT

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    SMT::build(1, 1, n);

    while (m--) {
        int op, l, r;
        scanf("%d%d%d", &op, &l, &r);

        if (op == 1) {
            int k;
            scanf("%d", &k);
            SMT::update(1, 1, n, l, r, k);
        } else if (op == 2) {
            int k;
            scanf("%d", &k);
            SMT::modify(1, 1, n, l, r, k);
        } else if (op == 3)
            printf("%lld\n", SMT::querysum(1, 1, n, l, r));
        else if (op == 4)
            printf("%d\n", SMT::querymax(1, 1, n, l, r));
        else
            printf("%d\n", SMT::queryhismax(1, 1, n, l, r));
    }

    return 0;
}

P4314 CPU 监控

给定 \(a_{1 \sim n}\)\(m\) 次操作:区间加 + 区间赋值 + 查询区间最大值 + 查询区间历史最大值。

\(n, m \le 10^5\)

维护二元组标记 \((add, cov)\) ,表示当前区间先加上 \(add\) 再覆盖为 \(cov\) ,同时记录 \((hisadd, hiscov)\) 表示上一次下传后的历史最大标记。

有一些细节需要处理,时间复杂度 \(O(m \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7;

int a[N];

int n, m;

namespace SMT {
int mx[N << 2], hismx[N << 2], tag[N << 2], histag[N << 2], cov[N << 2], hiscov[N << 2];
bool iscov[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void pushup(int x) {
    mx[x] = max(mx[ls(x)], mx[rs(x)]);
    hismx[x] = max(hismx[ls(x)], hismx[rs(x)]);
}

inline void spread_add(int x, int k, int hisk) {
    hismx[x] = max(hismx[x], mx[x] + hisk), mx[x] += k;

    if (iscov[x])
        hiscov[x] = max(hiscov[x], cov[x] + hisk), cov[x] += k;
    else
        histag[x] = max(histag[x], tag[x] + hisk), tag[x] += k;
}

inline void spread_cov(int x, int k, int hisk) {
    hismx[x] = max(hismx[x], hisk), mx[x] = k;
    hiscov[x] = max(hiscov[x], hisk), cov[x] = k, iscov[x] = true;
}

inline void pushdown(int x) {
    spread_add(ls(x), tag[x], histag[x]);
    spread_add(rs(x), tag[x], histag[x]);
    tag[x] = histag[x] = 0;

    if (iscov[x]) {
        spread_cov(ls(x), cov[x], hiscov[x]);
        spread_cov(rs(x), cov[x], hiscov[x]);
        cov[x] = hiscov[x] = -INT_MAX, iscov[x] = false;
    }
}

void build(int x, int l, int r) {
    if (l == r) {
        mx[x] = hismx[x] = a[l];
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    pushup(x);
}

void update(int x, int nl, int nr, int l, int r, int k) {
    if (l <= nl && nr <= r) {
        spread_add(x, k, k);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);

    pushup(x);
}

void cover(int x, int nl, int nr, int l, int r, int k) {
    if (l <= nl && nr <= r) {
        spread_cov(x, k, k);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        cover(ls(x), nl, mid, l, r, k);

    if (r > mid)
        cover(rs(x), mid + 1, nr, l, r, k);

    pushup(x);
}

int querymax(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return mx[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return querymax(ls(x), nl, mid, l, r);
    else if (l > mid)
        return querymax(rs(x), mid + 1, nr, l, r);
    else
        return max(querymax(ls(x), nl, mid, l, r), querymax(rs(x), mid + 1, nr, l, r));
}

int queryhismax(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return hismx[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return queryhismax(ls(x), nl, mid, l, r);
    else if (l > mid)
        return queryhismax(rs(x), mid + 1, nr, l, r);
    else
        return max(queryhismax(ls(x), nl, mid, l, r), queryhismax(rs(x), mid + 1, nr, l, r));
}
} // namespace SMT

signed main() {
    cin >> n;

    for (int i = 1; i <= n; ++i)
        cin >> a[i];

    SMT::build(1, 1, n);
    cin >> m;

    while (m--) {
        char op;
        int l, r;
        cin >> op >> l >> r;

        if (op == 'Q')
            printf("%d\n", SMT::querymax(1, 1, n, l, r));
        else if (op == 'A')
            printf("%d\n", SMT::queryhismax(1, 1, n, l, r));
        else if (op == 'P') {
            int k;
            cin >> k;
            SMT::update(1, 1, n, l, r, k);
        } else {
            int k;
            cin >> k;
            SMT::cover(1, 1, n, l, r, k);
        }
    }

    return 0;
}

UOJ164. 【清华集训2015】V

给出序列 \(a_{1 \sim n}\)\(m\) 次操作,操作有:

  • 1 l r k :将区间 \([l, r]\) 加上 \(k\)
  • 2 l r k :将区间 \([l, r]\) 减去 \(k\) 后对 \(0\)\(\max\)
  • 3 l r k :将区间 \([l, r]\) 赋值为 \(k\)
  • 4 x :查询 \(a_x\)
  • 5 x :查询 \(x\) 处的历史最大值。

\(n, m \le 5 \times 10^5\)

定义标记 \((ad, k)\) 表示加上 \(ad\) 后与 \(k\)\(\max\) ,则三种操作可以表示为 \((k, 0)\)\((-k, 0)\)\((-\infty, k)\) ,同时维护当前标记与历史标记。

对于标记的合并,考虑合并两个标记 \((a, b)\)\((c, d)\) ,则 \(x\) 会变为 \(\max(\max(x + a, b) + c, d) = \max(x + a + c, \max(b + c, d))\)

对于与历史标记的取 \(\max\) ,考虑将标记视为一次函数,\(k - ad\) 的左边斜率为 \(0\) ,右边斜率为 \(1\) 。画图可以得知 \(\max((a, b), (c, d)) = (\max(a, c), \max(b, d))\) ,另一个理解是由于这两个函数取 \(\max\) 后一定也是左边斜率为 \(0\) 、右边斜率为 \(1\) 的图象,而自变量很小时显然取 \(\max(b, d)\) 优,自变量很大时显然取 \(\max(a, c)\) 优。

由于都是单点查询, 可以不用维护区间信息,直接递归到叶子时对原序列修改即可。

时间复杂度 \(O(m \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const int N = 5e5 + 7;

ll a[N];

int n, m;

namespace SMT {
struct Tag {
    ll ad, k;

    inline Tag() : ad(0), k(-inf) {}

    inline Tag(ll _ad, ll _k) : ad(_ad), k(_k) {}

    inline friend Tag operator + (Tag a, Tag b) {
        return Tag(max(a.ad + b.ad, -inf), max(a.k + b.ad, b.k));
    }

    inline friend Tag operator | (Tag a, Tag b) {
        return Tag(max(a.ad, b.ad), max(a.k, b.k));
    }
} tag[N << 2], histag[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void spread(int x, Tag k, Tag hisk) {
    histag[x] = histag[x] | (tag[x] + hisk), tag[x] = tag[x] + k;
}

inline void pushdown(int x) {
    spread(ls(x), tag[x], histag[x]), spread(rs(x), tag[x], histag[x]), tag[x] = histag[x] = Tag();;
}

void update(int x, int nl, int nr, int l, int r, Tag k) {
    if (l <= nl && nr <= r) {
        spread(x, k, k);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);
}

ll querymax(int x, int nl, int nr, int p) {
    if (nl == nr)
        return max(a[p] + tag[x].ad, tag[x].k);

    pushdown(x);
    int mid = (nl + nr) >> 1;
    return p <= mid ? querymax(ls(x), nl, mid, p) : querymax(rs(x), mid + 1, nr, p);
}

ll queryhismax(int x, int nl, int nr, int p) {
    if (nl == nr)
        return max(a[p] + histag[x].ad, histag[x].k);

    pushdown(x);
    int mid = (nl + nr) >> 1;
    return p <= mid ? queryhismax(ls(x), nl, mid, p) : queryhismax(rs(x), mid + 1, nr, p);
}
} // namespace SMT

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    while (m--) {
        int op;
        scanf("%d", &op);

        if (op == 1) {
            int l, r, k;
            scanf("%d%d%d", &l, &r, &k);
            SMT::update(1, 1, n, l, r, {k, 0});
        } else if (op == 2) {
            int l, r, k;
            scanf("%d%d%d", &l, &r, &k);
            SMT::update(1, 1, n, l, r, {-k, 0});
        } else if (op == 3) {
            int l, r, k;
            scanf("%d%d%d", &l, &r, &k);
            SMT::update(1, 1, n, l, r, {-inf, k});
        } else if (op == 4) {
            int k;
            scanf("%d", &k);
            printf("%lld\n", SMT::querymax(1, 1, n, k));
        }
        else {
            int k;
            scanf("%d", &k);
            printf("%lld\n", SMT::queryhismax(1, 1, n, k));
        }
    }

    return 0;
}

猫树

基于线段树的序列结构,特点:

  • 序列静态(如果是尾部插入应该也是可以的)。
  • 区间查询复杂度优秀。
  • 维护的信息需要支持结合律和快速合并。

对于线段树上的一个区间 \([l, r]\) ,可以将其分为 \([l, mid]\)\([mid + 1, r]\) 。考虑从中点出发,分别向两边遍历并维护要处理的信息。预处理部分时空复杂度 \(O(n \log n)\) (忽略合并带来的时空开销)。

void build(int x, int l, int r, int d) {
    if (l == r) {
        pos[l] = x;
        return;
    }

    int mid = (l + r) >> 1;
    nd[d][mid] = Node(a[mid]);

    for (int i = mid - 1; i >= l; --i)
        nd[d][i] = Node(a[i]) + nd[d][i + 1];

    nd[d][mid + 1] = Node(a[mid + 1]);

    for (int i = mid + 2; i <= r; ++i)
        nd[d][i] = nd[d][i - 1] + Node(a[i]);

    build(ls(x), l, mid, d + 1), build(rs(x), mid + 1, r, d + 1);
}

查询时找到最深的包含 \([ql, qr]\) 的区间,那么只要将 \([ql, mid]\)\([mid + 1, qr]\) 的信息合并即可。时间复杂度为合并信息的复杂度。

接下来考虑如何找到最深的包含 \([ql, qr]\) 的区间,不难发现其为 \(ql, qr\) 两个叶子的 LCA。考虑堆式建树(将序列补成 \(2\) 的幂然后建树),此时线段树上两个点的 LCA 编号,就是两个点二进制下的 LCP ,而 \(x, y\) 在二进制下的 LCP 即为为 x >> __lg(x ^ y) ,于是可以 \(O(1)\) 找到该点。

inline Node query(int l, int r) {
    if (l > r)
        return Node();

    if (l == r)
        return Node(a[l]);

    int d = __lg(pos[l]) - __lg(pos[l] ^ pos[r]);
    return nd[d][l] + nd[d][r];
}

事实上猫树可以扩展到树上,支持快速静态链信息查询。考虑对树建立点分树结构,预处理每个点到重心的信息(是否包括重心、方向性都要考虑),然后每次询问就可以拆为两个已经预处理信息的合并。询问时找到这条链中点分树上最浅的重心合并信息即可,这个可以 ST 表预处理后 \(O(1)\) 查询。时间复杂度和序列的情况是一样的。

GSS1 - Can you answer these queries I

给定 \(a_{1 \sim n}\)\(m\) 次查询区间最大子段和。

\(n, m \le 5 \times 10^4\)

考虑如何合并信息,一个区间的答案只会在 \([ql, mid]\)\([mid + 1, qr]\) 以及跨过 \(mid\) 的区间的子段和中产生。对于前两者可以直接预处理,对于第三个可以预处理最大前/后缀和得到。

时间复杂度 \(O(n \log n + q)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7, LOGN = 17;

int a[N];

int n, m;

namespace CatTree {
struct Node {
    int ans, mxsum;
} nd[LOGN][N];

int pos[N];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

void build(int x, int l, int r, int d) {
    if (l == r) {
        pos[l] = x;
        return;
    }

    int mid = (l + r) >> 1;
    nd[d][mid] = (Node) {a[mid], a[mid]};

    for (int i = mid - 1, sum = a[mid], res = a[mid]; i >= l; --i) {
        sum += a[i], res = max(res, 0) + a[i];
        nd[d][i] = (Node) {max(res, nd[d][i + 1].ans), max(sum, nd[d][i + 1].mxsum)};
    }

    nd[d][mid + 1] = (Node) {a[mid + 1], a[mid + 1]};

    for (int i = mid + 2, sum = a[mid + 1], res = a[mid + 1]; i <= r; ++i) {
        sum += a[i], res = max(res, 0) + a[i];
        nd[d][i] = (Node) {max(res, nd[d][i - 1].ans), max(sum, nd[d][i - 1].mxsum)};
    }

    build(ls(x), l, mid, d + 1), build(rs(x), mid + 1, r, d + 1);
}

inline int query(int l, int r) {
    if (l == r)
        return a[l];

    int d = __lg(pos[l]) - __lg(pos[l] ^ pos[r]);
    return max(max(nd[d][l].ans, nd[d][r].ans), nd[d][l].mxsum + nd[d][r].mxsum);
}
} // namespace CatTree

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    CatTree::build(1, 1, n = 1 << (__lg(n - 1) + 1), 1);
    scanf("%d", &m);

    while (m--) {
        int l, r;
        scanf("%d%d", &l, &r);
        printf("%d\n", CatTree::query(l, r));
    }

    return 0;
}

GSS5 - Can you answer these queries V

给定 \(a_{1 \sim n}\)\(m\) 次查询左端点在\([x_1, y_1]\)之间、右端点在\([x_2, y_2]\)之间的区间 \([l, r]\) 的最大子段和。

\(n, m \le 10^4\)

对于 \(y_1 < x_2\) 的情况,答案即为 \([x_1, y_1]\) 的最大后缀和加上 \((y_1, x_2)\) 的区间和加上 \([x_2, y_2]\) 的最大前缀和。

否则,由于区间有重叠,分类讨论:

  • \(l, r\in [x_2, y_1]\) :答案为 \([x_2, y_1]\) 的最大子段和。
  • \(l \in [x_1, x_2), r \in [x_2, y_1]\) :答案为 \([x_1, x_2)\) 的最大后缀和加 \([x_2, y_1]\) 的最大前缀和。
  • \(l \in [x_2, y_1], r \in (y_1, y_2]\) :答案为 \([x_2, y_1]\) 的最大后缀和加 \((y_1, y_2]\) 的最大前缀和。
  • \(l \in [x_1, x_2), r \in (y_1, y_2]\) :答案为 \([x_1, x_2)\) 的最大后缀和加 \([x_2, y_1]\) 的区间和加 \((y_1, y_2]\) 的最大前缀和。

发现只要处理区间最大子段和、区间和、区间最大前缀和、区间最大后缀和即可,不难用猫树做到 \(O(n \log n + q)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 2e4 + 7, LOGN = 15;

int a[N];

int n, m;

namespace CatTree {
struct Node {
    int sum, ans, premxsum, sufmxsum;

    inline Node(const int k = 0) : sum(k), ans(k), premxsum(k), sufmxsum(k) {}

    inline friend Node operator + (const Node &a, const Node &b) {
        Node c;
        c.sum = a.sum + b.sum;
        c.ans = max(max(a.ans, b.ans), a.sufmxsum + b.premxsum);
        c.premxsum = max(a.premxsum, a.sum + b.premxsum);
        c.sufmxsum = max(b.sufmxsum, b.sum + a.sufmxsum);
        return c;
    }
} nd[LOGN][N];

int pos[N];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

void build(int x, int l, int r, int d) {
    if (l == r) {
        pos[l] = x;
        return;
    }

    int mid = (l + r) >> 1;
    nd[d][mid] = Node(a[mid]);

    for (int i = mid - 1; i >= l; --i)
        nd[d][i] = Node(a[i]) + nd[d][i + 1];

    nd[d][mid + 1] = Node(a[mid + 1]);

    for (int i = mid + 2; i <= r; ++i)
        nd[d][i] = nd[d][i - 1] + Node(a[i]);

    build(ls(x), l, mid, d + 1), build(rs(x), mid + 1, r, d + 1);
}

inline Node query(int l, int r) {
    if (l > r)
        return Node();

    if (l == r)
        return Node(a[l]);

    int d = __lg(pos[l]) - __lg(pos[l] ^ pos[r]);
    return nd[d][l] + nd[d][r];
}
} // namespace CatTree

signed main() {
    int T;
    scanf("%d", &T);

    while (T--) {
        scanf("%d", &n);

        for (int i = 1; i <= n; ++i)
            scanf("%d", a + i);

        CatTree::build(1, 1, n = 1 << (__lg(n - 1) + 1), 1);
        scanf("%d", &m);

        while (m--) {
            int l1, r1, l2, r2;
            scanf("%d%d%d%d", &l1, &r1, &l2, &r2);

            if (r1 < l2)
                printf("%d\n", CatTree::query(l1, r1).sufmxsum + CatTree::query(r1 + 1, l2 - 1).sum + 
                    CatTree::query(l2, r2).premxsum);
            else
                printf("%d\n", max(max(max(CatTree::query(l2, r1).ans,
                    CatTree::query(l1, l2 - 1).sufmxsum + CatTree::query(l2, r1).premxsum),
                    CatTree::query(l2, r1).sufmxsum + CatTree::query(r1 + 1, r2).premxsum),
                    CatTree::query(l1, l2 - 1).sufmxsum + CatTree::query(l2, r1).sum + CatTree::query(r1 + 1, r2).premxsum));
        }
    }

    return 0;
}

KTT

\(n\) 个一次函数 \(k_i x_i + b_i\) ,其中 \(k_i \geq 0\) 。初始时 \(x_i = 0\) ,每次操作为对 \(x\) 区间加正数 \(k\) 或查询 \(\max_{i = l}^r k_i x_i + b_i\)

可以发现区间加时每个点乘的倍数不一样,并且一个区间在增量达到一定程度时,最值取到的位置会发生变化。

考虑势能线段树,对每个点维护 \(x_i = 0\) 时值最大的函数,额外维护一个阈值 \(t\) ,表示增量至少达到 \(t\) 时区间内有一个函数的最大值会改变。

  • 信息更新:pushup 的时候对两个儿子的 \(t\)\(\min\) ,然后再算儿子的最大函数的交点横坐标即可。
  • 区间修改:递归到 \(k \le t\) 的子树打标记即可。
  • 区间查询:直接查询。

势能分析得到的时间复杂度为 \(O((n + m) \log^3 n)\) (三只 \(\log\) 是修改的,查询只要一只 \(\log\) )。

称这种结构为 KTT (Kinetic Tournament Tree)。KTT 的实际表现很优秀,基本都跑得很不满,和两只 \(\log\) 差不多。

本质就是对每个点维护一个一次函数,其中斜率 \(k\) 表示同时变化的数的数量,截距 \(b\) 表示值,然后不断在线段树上向上更新。

struct Line {
    ll k, b;

    inline Line operator + (const Line &rhs) const {
        return (Line){k + rhs.k, b + rhs.b};
    }

    inline friend pair<Line, ll> cmp(Line a, Line b) { // 返回 x = 0 的较大值以及反超点
        if (a.k == b.k ? a.b < b.b : a.k < b.k)
            swap(a, b);

        return a.b >= b.b ? make_pair(a, inf) : make_pair(b, (b.b - a.b) / (a.k - b.k));
    }
};

P5693 EI 的第六分块

给出整数序列 \(a_{1 \sim n}\)\(m\) 次操作,操作有:

  • 1 l r k :给区间 \([l,r]\) 中每个数加上 \(k \in [1, 10^6]\)
  • 2 l r :查询区间 \([l,r]\) 的最大子段和(可以为空)。

\(n, m \le 4 \times 10^5\)

先不考虑区间加,对于区间最大子段和问题,通常的处理思路是在线段树的每个节点维护 \((lmx, rmx, sum, ans)\)

由于这几项都与选取的区间长度有关,考虑额外维护一个阈值 \(t\) 表示这些信息第一次变化的最小增量,然后直接上 KTT 即可。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const int N = 4e5 + 7;
 
int a[N];
 
int n, m;

namespace KTT {
struct Line {
    ll k, b;

    inline Line operator + (const Line &rhs) const {
        return (Line){k + rhs.k, b + rhs.b};
    }

    inline friend pair<Line, ll> cmp(Line a, Line b) {
        if (a.k == b.k ? a.b < b.b : a.k < b.k)
            swap(a, b);

        return a.b >= b.b ? make_pair(a, inf) : make_pair(b, (b.b - a.b) / (a.k - b.k));
    }
};

struct Node {
    Line lmx, rmx, sum, ans;

    ll t;

    inline friend Node operator + (Node a, Node b) {
        Node c;
        c.t = min(a.t, b.t), c.sum = a.sum + b.sum;

        auto res = cmp(a.lmx, a.sum + b.lmx);
        c.lmx = res.first, c.t = min(c.t, res.second);

        res = cmp(b.rmx, b.sum + a.rmx);
        c.rmx = res.first, c.t = min(c.t, res.second);

        res = cmp(a.ans, b.ans);
        c.ans = res.first, c.t = min(c.t, res.second);

        res = cmp(c.ans, a.rmx + b.lmx);
        c.ans = res.first, c.t = min(c.t, res.second);

        return c;
    }
} nd[N << 2];

ll tag[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void spread(int x, ll k) {
    nd[x].lmx.b += nd[x].lmx.k * k;
    nd[x].rmx.b += nd[x].rmx.k * k;
    nd[x].sum.b += nd[x].sum.k * k;
    nd[x].ans.b += nd[x].ans.k * k;
    tag[x] += k, nd[x].t -= k;
}

inline void pushdown(int x) {
    if (tag[x])
        spread(ls(x), tag[x]), spread(rs(x), tag[x]), tag[x] = 0;
}

void build(int x, int l, int r) {
    if (l == r) {
        nd[x] = (Node){(Line){1, a[l]}, (Line){1, a[l]}, (Line){1, a[l]}, (Line){1, a[l]}, inf};
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    nd[x] = nd[ls(x)] + nd[rs(x)];
}

void maintain(int x, int l, int r, ll k) {
    if (k <= nd[x].t) {
        spread(x, k);
        return;
    }

    pushdown(x);
    int mid = (l + r) >> 1;
    maintain(ls(x), l, mid, k), maintain(rs(x), mid + 1, r, k);
    nd[x] = nd[ls(x)] + nd[rs(x)];
}

void update(int x, int nl, int nr, int l, int r, ll k) {
    if (l <= nl && nr <= r) {
        maintain(x, nl, nr, k);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);

    nd[x] = nd[ls(x)] + nd[rs(x)];
}

Node query(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return nd[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return query(ls(x), nl, mid, l, r);
    else if (l > mid)
        return query(rs(x), mid + 1, nr, l, r);
    else
        return query(ls(x), nl, mid, l, r) + query(rs(x), mid + 1, nr, l, r);
}
} // namespace KTT

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    KTT::build(1, 1, n);

    while (m--) {
        int op, l, r;
        scanf("%d%d%d", &op, &l, &r);

        if (op == 1) {
            int k;
            scanf("%d", &k);
            KTT::update(1, 1, n, l, r, k);
        } else
            printf("%lld\n", max(0ll, KTT::query(1, 1, n, l, r).ans.b));
    }

    return 0;
}

P6792 [SNOI2020] 区间和

给出整数序列 \(a_{1 \sim n}\)\(m\) 次操作,操作有:

  • 0 l r k :给区间 \([l,r]\) 中每个数与 \(k\)\(\max\)
  • 1 l r :查询区间 \([l,r]\) 的最大子段和(可以为空)。

\(n \le 10^5\)\(m \le 2 \times 10^5\)

同时维护吉司机线段树和 KTT 。由于吉司机线段树的打标记是对整个区间的最小值打标记,因此 KTT 上的斜率应定义为最小值个数。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const int N = 1e5 + 7;

int a[N];

int n, m;

namespace KTT {
struct Line {
    ll k, b;

    inline Line operator + (const Line &rhs) const {
        return (Line) {k + rhs.k, b + rhs.b};
    }

    inline friend pair<Line, ll> cmp(Line a, Line b) {
        if (a.k == b.k ? a.b < b.b : a.k < b.k)
            swap(a, b);

        return a.b >= b.b ? make_pair(a, inf) : make_pair(b, (b.b - a.b) / (a.k - b.k));
    }
};

struct Node {
    Line lmx, rmx, sum, ans;

    ll t;

    inline Node reset() {
        Node res = *this;
        res.lmx.k = res.rmx.k = res.sum.k = res.ans.k = 0;
        return res;
    }
    
    inline friend Node operator + (const Node &a, const Node &b) {
        Node c;
        c.t = min(a.t, b.t), c.sum = a.sum + b.sum;

        auto res = cmp(a.lmx, a.sum + b.lmx);
        c.lmx = res.first, c.t = min(c.t, res.second);

        res = cmp(b.rmx, b.sum + a.rmx);
        c.rmx = res.first, c.t = min(c.t, res.second);

        res = cmp(a.ans, b.ans);
        c.ans = res.first, c.t = min(c.t, res.second);

        res = cmp(c.ans, a.rmx + b.lmx);
        c.ans = res.first, c.t = min(c.t, res.second);
        return c;
    }
} nd[N << 2];

ll tag[N << 2], mn[N << 2], sec[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void pushup(int x) {
    if (mn[ls(x)] == mn[rs(x)]) {
        mn[x] = mn[ls(x)], sec[x] = min(sec[ls(x)], sec[rs(x)]);
        nd[x] = nd[ls(x)] + nd[rs(x)];
    } else if (mn[ls(x)] < mn[rs(x)]) {
        mn[x] = mn[ls(x)], sec[x] = min(sec[ls(x)], mn[rs(x)]);
        nd[x] = nd[ls(x)] + nd[rs(x)].reset();
    } else {
        mn[x] = mn[rs(x)], sec[x] = min(mn[ls(x)], sec[rs(x)]);
        nd[x] = nd[ls(x)].reset() + nd[rs(x)];
    }
}

void build(int x, int l, int r) {
    tag[x] = -inf;

    if (l == r) {
        nd[x] = (Node) {(Line) {1, a[l]}, (Line) {1, a[l]}, (Line) {1, a[l]}, (Line) {1, a[l]}, inf};
        mn[x] = a[l], sec[x] = inf;
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    pushup(x);
}

inline void spread(int x, ll k) {
    if (k <= mn[x])
        return;

    ll c = k - mn[x];
    mn[x] = k, tag[x] = max(tag[x], k);
    nd[x].t -= c;
    nd[x].lmx.b += nd[x].lmx.k * c;
    nd[x].rmx.b += nd[x].rmx.k * c;
    nd[x].sum.b += nd[x].sum.k * c;
    nd[x].ans.b += nd[x].ans.k * c;
}

inline void pushdown(int x) {
    if (tag[x] != -inf)
        spread(ls(x), tag[x]), spread(rs(x), tag[x]), tag[x] = -inf;
}

void maintain(int x, int l, int r, ll k) {
    if (k - mn[x] <= nd[x].t) {
        spread(x, k);
        return;
    }

    pushdown(x);
    int mid = (l + r) >> 1;
    maintain(ls(x), l, mid, k), maintain(rs(x), mid + 1, r, k);
    pushup(x);
}

void update(int x, int nl, int nr, int l, int r, ll k) {
    if (mn[x] >= k)
        return;

    if (l <= nl && nr <= r && k < sec[x]) {
        maintain(x, nl, nr, k);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);

    pushup(x);
}

Node query(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return nd[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return query(ls(x), nl, mid, l, r);
    else if (l > mid)
        return query(rs(x), mid + 1, nr, l, r);
    else
        return query(ls(x), nl, mid, l, r) + query(rs(x), mid + 1, nr, l, r);
}
} // namespace KTT

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    KTT::build(1, 1, n);

    while (m--) {
        int op, l, r;
        scanf("%d%d%d", &op, &l, &r);

        if (op)
            printf("%lld\n", max(KTT::query(1, 1, n, l, r).ans.b, 0ll));
        else {
            int k;
            scanf("%d", &k);
            KTT::update(1, 1, n, l, r, k);
        }
    }

    return 0;
}

CF1178G The Awesomest Vertex

给定一棵树,每个点有两个值 \(a, b\) ,定义 \(A, B\) 表示其到根链上的 \(a, b\) 的和,一个点的权值定义为 \(|A| \times |B|\)\(m\) 次操作,操作有:

  • 1 x k :将 \(a_x\) 加上正数 \(k\)
  • 2 x :求 \(x\) 子树内的最大权值。

\(n \le 2 \times 10^5\)\(q \le 10^5\)

首先不难将两个转化为区间加和区间查最大权值。由于 \(B\) 固定,因此考虑 \(A\) 的维护。

直接维护这个 \(|A|\) 是困难的,但是注意到 \(|A| = \max(A, -A)\) ,得到 \(|A| \times |B| = \max(A \times B, A \times -B)\) 。因此可以开两棵 KTT ,一棵维护 \(A \times B\) ,初始斜率为 \(B\) ;另一棵维护 \(A \times -B\) ,初始斜率为 \(-B\) 。剩下就是套板子了。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const int N = 2e5 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

ll a[N], b[N];
int fa[N], in[N], out[N], id[N];

int n, m, dfstime;

void dfs(int u) {
    id[in[u] = ++dfstime] = u, a[u] += a[fa[u]], b[u] += b[fa[u]];

    for (int v : G.e[u])
        dfs(v);

    out[u] = dfstime;
}

struct KTT {
    struct Line {
        ll k, b;

        inline friend pair<Line, ll> cmp(Line a, Line b) {
            if (a.k == b.k ? a.b < b.b : a.k < b.k)
                swap(a, b);

            return a.b >= b.b ? make_pair(a, inf) : make_pair(b, (b.b - a.b) / (a.k - b.k));
        }
    };

    struct Node {
        Line mx;

        ll t;

        inline friend Node operator + (const Node &a, const Node &b) {
            auto res = cmp(a.mx, b.mx);
            return (Node) {res.first, min(min(a.t, b.t), res.second)};
        }
    } nd[N << 2];

    ll tag[N << 2];

    inline int ls(int x) {
        return x << 1;
    }

    inline int rs(int x) {
        return x << 1 | 1;
    }

    inline void spread(int x, ll k) {
        tag[x] += k, nd[x].t -= k, nd[x].mx.b += nd[x].mx.k * k;
    }

    inline void pushdown(int x) {
        if (tag[x])
            spread(ls(x), tag[x]), spread(rs(x), tag[x]), tag[x] = 0;
    }

    void build(int x, int l, int r, int op) {
        if (l == r) {
            nd[x] = (Node) {(Line) {b[id[l]] * op, a[id[l]] * b[id[l]] * op}, inf};
            return;
        }

        int mid = (l + r) >> 1;
        build(ls(x), l, mid, op), build(rs(x), mid + 1, r, op);
        nd[x] = nd[ls(x)] + nd[rs(x)];
    }

    void maintain(int x, int l, int r, ll k) {
        if (k <= nd[x].t) {
            spread(x, k);
            return;
        }

        pushdown(x);
        int mid = (l + r) >> 1;
        maintain(ls(x), l, mid, k), maintain(rs(x), mid + 1, r, k);
        nd[x] = nd[ls(x)] + nd[rs(x)];
    }

    void update(int x, int nl, int nr, int l, int r, ll k) {
        if (l <= nl && nr <= r) {
            maintain(x, nl, nr, k);
            return;
        }

        pushdown(x);
        int mid = (nl + nr) >> 1;

        if (l <= mid)
            update(ls(x), nl, mid, l, r, k);

        if (r > mid)
            update(rs(x), mid + 1, nr, l, r, k);

        nd[x] = nd[ls(x)] + nd[rs(x)];
    }

    Node query(int x, int nl, int nr, int l, int r) {
        if (l <= nl && nr <= r)
            return nd[x];

        pushdown(x);
        int mid = (nl + nr) >> 1;

        if (r <= mid)
            return query(ls(x), nl, mid, l, r);
        else if (l > mid)
            return query(rs(x), mid + 1, nr, l, r);
        else
            return query(ls(x), nl, mid, l, r) + query(rs(x), mid + 1, nr, l, r);
    }
} ktt1, ktt2;

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 2; i <= n; ++i)
        scanf("%d", fa + i), G.insert(fa[i], i);

    for (int i = 1; i <= n; ++i)
        scanf("%lld", a + i);

    for (int i = 1; i <= n; ++i)
        scanf("%lld", b + i);

    dfs(1), ktt1.build(1, 1, n, 1), ktt2.build(1, 1, n, -1);

    while (m--) {
        int op, x;
        scanf("%d%d", &op, &x);

        if (op == 1) {
            int k;
            scanf("%d", &k);
            ktt1.update(1, 1, n, in[x], out[x], k);
            ktt2.update(1, 1, n, in[x], out[x], k);
        } else
            printf("%lld\n", max(ktt1.query(1, 1, n, in[x], out[x]).mx.b, 
                ktt2.query(1, 1, n, in[x], out[x]).mx.b));
    }

    return 0;
}

P9288 [ROI 2018] Innophone

给定 \(n\)\(x_i, y_i\) ,选取 \(a, b\) 最大化:

\[\sum_{i = 1}^n [a \le x_i] a + [a > x_i] [b \le y_i] b \]

\(n \le 1.5 \times 10^5\)

考虑强制令 \(a, b\) 取到某个 \(x, y\) ,容易发现这样不会使得答案更劣。

考虑扫描线,从小到大枚举 \(a\) ,此时 \(a\) 的贡献是容易算出的。对于剩下 \(a > x\) 的位置,\(a\) 变大时此部分贡献就会加入某些位置,记 \(rk_i\) 表示 \(y \ge y_i\) 的位置数量,接下来考虑动态维护 \(y_i \times rk_i\) 的最大值。考虑以 \(y\) 为下标建立 KTT,那么加入一个数相当于对一个前缀加上一个斜率为 \(1\) 的一次函数,每个点的初始斜率为 \(y_i\) ,初始截距为 \(y_i \times rk_i\) ,不难直接维护。

还有个差不多的题目:CF436F Banners

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const int N = 1.5e5 + 7;

struct Node {
    int x, y;

    inline bool operator < (const Node &rhs) const {
        return x < rhs.x;
    }
} nd[N];

vector<int> vec;

int n;

namespace KTT {
struct Line {
    ll k, b;

    inline friend pair<Line, ll> cmp(Line a, Line b) {
        if (a.k == b.k ? a.b < b.b : a.k < b.k)
            swap(a, b);

        return a.b >= b.b ? make_pair(a, inf) : make_pair(b, (b.b - a.b) / (a.k - b.k));
    }
};

struct Node {
    Line mx;

    ll t;

    inline friend Node operator + (const Node &a, const Node &b) {
        Node c;
        auto res = cmp(a.mx, b.mx);
        c.mx = res.first, c.t = min(min(a.t, b.t), res.second);
        return c;
    }
} nd[N << 2];

ll tag[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void spread(int x, ll k) {
    nd[x].mx.b += nd[x].mx.k * k, nd[x].t -= k, tag[x] += k;
}

inline void pushdown(int x) {
    if (tag[x])
        spread(ls(x), tag[x]), spread(rs(x), tag[x]), tag[x] = 0;
}

void build(int x, int l, int r) {
    if (l == r) {
        nd[x].mx.k = vec[l], nd[x].t = inf;
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    nd[x] = nd[ls(x)] + nd[rs(x)];
}

void maintain(int x, int l, int r, ll k) {
    if (k <= nd[x].t) {
        spread(x, k);
        return;
    }

    pushdown(x);
    int mid = (l + r) >> 1;
    maintain(ls(x), l, mid, k), maintain(rs(x), mid + 1, r, k);
    nd[x] = nd[ls(x)] + nd[rs(x)];
}

void update(int x, int nl, int nr, int l, int r, ll k) {
    if (l <= nl && nr <= r) {
        maintain(x, nl, nr, k);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);

    nd[x] = nd[ls(x)] + nd[rs(x)];
}
} // namespace KTT

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i)
        scanf("%d%d", &nd[i].x, &nd[i].y), vec.emplace_back(nd[i].y);

    sort(nd + 1, nd + n + 1);
    sort(vec.begin(), vec.end()), vec.erase(unique(vec.begin(), vec.end()), vec.end());
    int m = vec.size() - 1;
    KTT::build(1, 0, m);
    ll ans = 0;

    for (int i = 1; i <= n; ++i) {
        if (i == 1 || nd[i].x != nd[i - 1].x)
            ans = max(ans, 1ll * (n - i + 1) * nd[i].x + KTT::nd[1].mx.b);

        KTT::update(1, 0, m, 0, nd[i].y = lower_bound(vec.begin(), vec.end(), nd[i].y) - vec.begin(), 1);
    }

    printf("%lld", max(ans, KTT::nd[1].mx.b));
    return 0;
}

CF1830F The Third Grace

给定一个数轴上的 \(n\) 个区间和 \(m\) 个点,第 \(i\) 个区间覆盖坐标 \([l_i, r_i]\),第 \(i\) 个点在坐标 \(i\) 处,并且具有系数 \(p_i\)

定义一个区间的价值为:

  • 若区间内没有被激活的点,则代价为 \(0\)
  • 否则价值为在区间内坐标最大的被激活点的系数。

最初所有点都未激活,选择一些点激活,最大化所有区间的价值和。

\(n, m \le 10^6\)

\(f_i\) 表示考虑前 \(i\) 个点,且激活第 \(i\) 个点的最大价值,为了方便此时的价值不包括第 \(i\) 个点的贡献。则:

\[f_i = \max_{j < i} \{ f_j + p_j \times w(j, i) \} \]

其中 \(w(j, i)\) 表示 \(l \le j \le r < i\) 的区间数量,直接做可以做到 \(O(n^2)\)

考虑用数据结构维护 \(f_j + p_j \times w(j, i)\) ,每次 \(i \to i + 1\) 的时候,就把所有 \(r = i\) 的区间的 \([l, r]\) 的值加上相应的 \(p\) ,然后还要支持单点修改,直接上 KTT 即可。

时间复杂度 \(O(m \log^3 m)\) ,神奇的是 KTT 的三个 \(\log\) 能冲过 \(10^6\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const int N = 1e6 + 7;

struct Interval {
    int l, r;

    inline bool operator < (const Interval &rhs) const {
        return r < rhs.r;
    }
} a[N];

ll f[N];
int p[N];

int n, m;

namespace KTT {
struct Line {
    ll k, b;

    inline friend pair<Line, ll> cmp(Line a, Line b) {
        if (a.k == b.k ? a.b < b.b : a.k < b.k)
            swap(a, b);

        return a.b >= b.b ? make_pair(a, inf) : make_pair(b, (b.b - a.b) / (a.k - b.k));
    }
};

struct Node {
    Line mx;

    ll t;

    inline friend Node operator + (const Node &a, const Node &b) {
        auto res = cmp(a.mx, b.mx);
        return (Node) {res.first, min(min(a.t, b.t), res.second)};
    }
} nd[N << 2];

ll tag[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void spread(int x, ll k) {
    nd[x].mx.b += nd[x].mx.k * k, nd[x].t -= k, tag[x] += k;
}

inline void pushdown(int x) {
    if (tag[x])
        spread(ls(x), tag[x]), spread(rs(x), tag[x]), tag[x] = 0;
}

void build(int x, int l, int r) {
    tag[x] = 0;

    if (l == r) {
        nd[x] = (Node){(Line) {p[l], 0}, inf};
        return;
    }

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
    nd[x] = nd[ls(x)] + nd[rs(x)];
}

void modify(int x, int l, int r, int p, ll k) {
    if (l == r) {
        nd[x].mx.b = k;
        return;
    }

    pushdown(x);
    int mid = (l + r) >> 1;

    if (p <= mid)
        modify(ls(x), l, mid, p, k);
    else
        modify(rs(x), mid + 1, r, p, k);

    nd[x] = nd[ls(x)] + nd[rs(x)];
}

void maintain(int x, int l, int r, ll k) {
    if (k <= nd[x].t) {
        spread(x, k);
        return;
    }

    pushdown(x);
    int mid = (l + r) >> 1;
    maintain(ls(x), l, mid, k), maintain(rs(x), mid + 1, r, k);
    nd[x] = nd[ls(x)] + nd[rs(x)];
}

void update(int x, int nl, int nr, int l, int r, ll k) {
    if (l <= nl && nr <= r) {
        maintain(x, nl, nr, k);
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, k);

    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, k);

    nd[x] = nd[ls(x)] + nd[rs(x)];
}
} // namespace KTT

signed main() {
    int T;
    scanf("%d", &T);

    while (T--) {
        scanf("%d%d", &n, &m);

        for (int i = 1; i <= n; ++i)
            scanf("%d%d", &a[i].l, &a[i].r);

        sort(a + 1, a + n + 1);

        for (int i = 1; i <= m; ++i)
            scanf("%d", p + i);

        KTT::build(1, 0, m);

        for (int i = 1, j = 1; i <= m; ++i) {
            KTT::modify(1, 0, m, i, f[i] = KTT::nd[1].mx.b);

            for (; j <= n && a[j].r == i; ++j)
                KTT::update(1, 0, m, a[j].l, i, 1);
        }

        printf("%lld\n", KTT::nd[1].mx.b);
    }

    return 0;
}
posted @ 2025-02-06 20:13  wshcl  阅读(65)  评论(0)    收藏  举报