KTT 笔记

KTT 简介

KTT 是 2020 年集训队论文 中,由李白天写的一篇文章中的部分内容,文章名称为 “浅谈函数最值的动态维护”。其中 KTT 十分重要,他是一个线段树的扩展,通常情况下他用以解决下面所述问题,更高级和更有扩展性的用法难度过高,很少出现,并且我学不会我还没学会。

本文是我学习了这篇文章后总结的产物,带入了个人理解,并用更简单但不一定完全严谨的方法代替了文章中难以理解的部分,有问题会及时修正。

解决问题

常规线段树被用来解决序列上区间修改、区间求和或最值问题,但是这里的序列中每个部分都存储的是一个值,每次修改也是一个区间修改同样的值。

如果区间修改,但修改的值不一样怎么办?我们让序列中每个部分不存储一个值,而是一个函数 \(a_i x_i + b_i\)

询问区间和是简单的,现在讨论如何询问区间最值。我们需要给操作加一些限制,得以用 KTT 解决他,下面是具体操作:

  • \(1 \ \ \ l \ \ \ r \ \ \ k\):将区间 \(l\)\(r\)\(x_i\) 增加 \(k\)其中 \(k > 0\)
  • \(2 \ \ \ l \ \ \ r\):询问区间 \(l\)\(r\)最大值

这里我们规定了 \(x_i\) 只加不减以及每次只询问最大值,下面会说明原因。事实上,这也是 KTT 的最大限制。

大体思路

首先,我们不在序列中每个部分存储 \(a_i, b_i, x_i\),而是存成 \(a_i\)\(b_i\),每次修改其实就是 \(b_i += a_i \times k\)。这意味着每个位置的具体值就是 \(b_i\)

然后,在线段树中我们要维护一个极其关键的值,\(intr\),我们称其为交替阈值。其中,交替是 KTT 一个重要的操作,他使得 KTT 可以处理最大值的询问,也给 KTT 解决的问题带来了限制。

交替:本问题与常规线段树问题最大的不同是,即便是对一整段区间修改同样的东西,这段区间的最大值也有可能改变,这就叫交替。我们不关注交替之后谁是最大值,因为我们可以直接 pushdown 加 pushup 计算出来,我们只关注什么时候会交替,这时候就需要交替阈值这个东西了。

具体实现

我们对于每个节点维护一下几个值:

  • \({a_x, b_x}\):表示这一节点这一段的 \(b_x\) 最大的位置的 \(a_x\)\(b_x\)
  • \(lz_x\):懒标记。
  • \(intr_x\):交替阈值,每次这个节点的整个区间进行操作时,也就是 \(x_i\) 都增加 \(k\),我们就把交替阈值减少 \(k\),当他小于等于 \(0\) 时表示这个节点的这一段的最大值该交替了。

pushdown

根据维护的值的定义,直接操作就可以:

void pushdown(int x){
    for(int y : {x << 1, x << 1 | 1}){
        lz[y] += lz[x];
        s[y].b += lz[x] * s[y].a;
        intr[y] -= lz[x];
    }
    lz[x] = 0;
    return;
}

pushup

针对 \({a_x, b_x}\) 的操作很简单,直接找 \(b_x\) 较大的就行,但 pushup 还需要对 \(intr_x\) 更新,这时候有三种可能性,\(intr_ls\)\(intr_rs\) 和左右最大值交替的时间。为了找到这个时间,我们定义一个函数 inter,他就是找两个一次函数的交点并向上取整。特别的,如果这个交点是负数或没有交点,那么返回正无穷,因为显然这两个函数的较大值不会交替。

注意在 C++ 中,除法是向零取整,所以我们要让除数和被除数都是正数,具体方法可以参考代码。

int inter(node x, node y){
    int da = x.a - y.a, db = y.b - x.b;
    if(da >= 0 && db <= 0 || da <= 0 && db >= 0)
    return LONG_LONG_MAX;
    return (da > 0 ? (da + db - 1) / da : (-da - db + 1) / (-da));
}

void pushup(int x){
    s[x] = (s[x << 1] < s[x << 1 | 1] ? s[x << 1 | 1] : s[x << 1]);	//这里这样写是减小常数
    intr[x] = min({intr[x << 1], intr[x << 1 | 1], inter(s[x << 1], s[x << 1 | 1])});
    return;
}

rebuild

KTT 所特有的函数,是当 \(intr_i\) 小于等于 \(0\) 时使用的,操作也很简单,其实就是 pushdown 一下,然后如果有儿子的 \(intr\) 因为这一下 pushdown 而小于等于 \(0\) 了,再让儿子 rebuild 一下,最后给自己 pushup 一下。

如果某些一次操作之后,某个节点的 \(intr\) 有可能小于等于 \(0\),就调用它。

void rebuild(int x){
    if(intr[x] > 0)
    return;
    pushdown(x);
    rebuild(x << 1);
    rebuild(x << 1 | 1);
    pushup(x);
    return;
}

revise

单点修改。

void revise(int p, node k, int x = 1, int nl = 1, int nr = m){
    if(nl == nr){
        s[x] = k;
        return;
    }
    pushdown(x);
    int mid = nl + nr >> 1;
    if(mid >= p)
    revise(p, k, x << 1, nl, mid);
    else
    revise(p, k, x << 1 | 1, mid + 1, nr);
    pushup(x);
    return;
}

add

区间修改,像常规线段树一样,更改一大段时和 pushdown 很像,但是,普通的 pushdown(x) 之后,\(x\) 的儿子们的 \(intr\) 必然不会小于 \(0\)(因为 \(x\)\(intr\) 比他的儿子们都小,如果他的儿子们的 \(intr\) 小于等于 \(0\),他早就 rebuild 了),而区间修改会让 \(intr\) 小于等于 \(0\),所以要 rebuild 一下。

void add(int l, int r, int k, int x = 1, int nl = 1, int nr = m){
    if(nl >= l && nr <= r){
        s[x].b += s[x].a * k;
        intr[x] -= k;
        lz[x] += k;
        rebuild(x);
        return;
    }
    pushdown(x);
    int mid = nl + nr >> 1;
    if(mid >= l)
    add(l, r, k, x << 1, nl, mid);
    if(mid + 1 <= r)
    add(l, r, k, x << 1 | 1, mid + 1, nr);
    pushup(x);
    return;
}

query

区间查询,直接查询就行,记得及时 pushdown 和 pushup。

int query(int l, int r, int x = 1, int nl = 1, int nr = m){
    if(nl >= l && nr <= r)
    return s[x].b;
    pushdown(x);
    int mid = nl + nr >> 1;
    int ans = 0;
    if(mid >= l)
    ans = max(ans, query(l, r, x << 1, nl, mid));
    if(mid + 1 <= r)
    ans = max(ans, query(l, r, x << 1 | 1, mid + 1, nr));
    pushup(x);
    return ans;
}

build

建树。

void build(int x = 1, int nl = 1, int nr = m){
    lz[x] = 0;
    s[x] = {LONG_LONG_MIN, LONG_LONG_MIN};
    intr[x] = LONG_LONG_MAX;
    if(nl == nr)
    return;
    int mid = nl + nr >> 1;
    build(x << 1, nl, mid);
    build(x << 1 | 1, mid + 1, nr);
    return;
}

更多作用

其实根据李白天的意思,KTT 主要解决的是下述问题,是一个经典问题,他以更优的复杂度解决,但我认为下述问题其实是上述问题的扩展,类似线段树维护乘法和维护加法的关系。

下文称上述问题为常规问题,下述问题为扩展问题。

给定一个整数序列,支持两种操作:

  • \(1 \ \ \ l \ \ \ r \ \ \ k\):表示给区间 \(l\)\(r\) 中每个数加上 \(k\)其中 \(k > 0\)
  • \(2 \ \ \ l \ \ \ r\):表示查询区间 \(l\)\(r\)最大子段和(可以为空)。

限制与常规问题类似。

首先考虑不带修,我们线段树上的每个节点可以维护:

  • \(lma\):最大前缀和
  • \(rma\):最大后缀和
  • \(sum\):节点表示的这段的子段和
  • \(totma\):最大子段和

考虑这些信息的转移:

  • \(lma = \max(ls.lma, ls.sum + rs.lma)\)
  • \(rma = \max(re.rma, rs.sum + ls.rma)\)
  • \(sum = ls.sum + rs.sum\)
  • \(totma = \max(ls.totma, rs.totma, ls.rma + rs.lma)\)

假设我们这些 max 取的位置都不变,那么如果这段区间更新了,假设更新涉及到的长度是 \(l\),区间加的值是 \(x\),这个信息原本维护的值是 \(s\),那么其实就是 \(s = s + l\times x\),在 \(l\) 确定的情况下,容易发现这是一个一次函数,而 \(l\) 更改时,一定是有地方进行了交替,可以简单的考虑到,所以定义交替阈值 \(intr\),这里的 \(intr\) 要考虑很多,分别是 \(lma\) 的两个方案的交替、\(rma\) 的两个方案的交替和 \(totma\) 的三个方案的交替。

Show Code
#include
#define int long long
#define fi first
#define se second
#define mp make_pair
using namespace std;
auto mread = [](){int x;scanf("%lld", &x);return x;};
const int N = 4e5 + 5;
int n = mread(), q = mread(), p[N];
struct line{
    int a, b;
    friend line operator +(line a, line b){
        return {a.a + b.a, a.b + b.b};
    }
};
pair max2(line a, line b){
    if(a.a < b.a || a.a == b.a && a.b < b.b)
    swap(a,b);
	if(a.b >= b.b)
    return mp(a, LONG_LONG_MAX);
	return mp(b, (b.b - a.b) / (a.a - b.a));
}
struct node{
    line lma, rma, sum, totma;
    int intr;
    friend node operator +(node a, node b){
        node t;
		pair tmp;
		t.intr = min(a.intr, b.intr);
		tmp = max2(a.lma, b.lma + a.sum);
		t.lma = tmp.fi;
        t.intr = min(t.intr, tmp.se);
		tmp = max2(b.rma, a.rma + b.sum);
		t.rma = tmp.fi;
        t.intr = min(t.intr, tmp.se);
		tmp = max2(a.totma, b.totma);
		t.intr = min(t.intr, tmp.se);
		tmp = max2(tmp.fi, a.rma + b.lma);
		t.totma = tmp.fi;
        t.intr = min(t.intr, tmp.se);
		t.sum = a.sum + b.sum;
        return t;
    }
};
struct ktt{
    node a[N << 2];
    int lz[N << 2];
    void pushup(int x){
        a[x] = a[x << 1] + a[x << 1 | 1];
    }
    void push(int x, int v){
        lz[x] += v;
        a[x].intr -= v;
        a[x].lma.b += a[x].lma.a * v;
        a[x].rma.b += a[x].rma.a * v;
        a[x].sum.b += a[x].sum.a * v;
        a[x].totma.b += a[x].totma.a * v;
    }
    void pushdown(int x){
        for(int y : {x << 1, x << 1 | 1}){
            push(y, lz[x]);
        }
        lz[x] = 0;
    }
    void build(int x = 1, int l = 1, int r = n){
        lz[x] = 0;
        if(l == r){
            a[x].intr = LONG_LONG_MAX;
            a[x].lma = a[x].rma = a[x].totma = a[x].sum = {1, p[l]};
            return;
        }
        int mid = l + r >> 1;
        build(x << 1, l, mid);
        build(x << 1 | 1, mid + 1, r);
        pushup(x);
        return;
    }
    void rebuild(int x){
        if(a[x].intr >= 0)
        return;
        pushdown(x);
        rebuild(x << 1);
        rebuild(x << 1 | 1);
        pushup(x);
        return;
    }
    void add(int l, int r, int k, int x = 1, int nl = 1, int nr = n){
        if(nl >= l && nr <= r){
            push(x, k);
            rebuild(x);
            return;
        }
        int mid = nl + nr >> 1;
        pushdown(x);
        if(mid >= l)
        add(l, r, k, x << 1, nl, mid);
        if(mid + 1 <= r)
        add(l, r, k, x << 1 | 1, mid + 1, nr);
        pushup(x);
        return;
    }
    node query(int l, int r, int x = 1, int nl = 1, int nr = n){
        if(nl >= l && nr <= r){
            return a[x];
        }
        int mid = nl + nr >> 1, ans = 0;
        pushdown(x);
        if(!(mid >= l)){
            auto tmp = query(l, r, x << 1 | 1, mid + 1, nr);
            pushup(x);
            return tmp;
        }
        if(r <= mid)
        return query(l, r, x << 1, nl, mid);
		if(l > mid)
        return query(l, r, x << 1 | 1, mid + 1, nr);
        return query(l, r, x << 1, nl, mid) + query(l, r, x << 1 | 1, mid + 1, nr);
    }
}T;
signed main(){
    for(int i = 1; i <= n; i ++){
        p[i] = mread();
    }
    T.build();
    while(q --){
        int op = mread();
        if(op == 1){
            int l = mread(), r = mread(), k = mread();
            T.add(l, r, k);
        }
        else{
            int l = mread(), r = mread();
            printf("%lld\n", max(0ll, T.query(l, r).totma.b));
        }
    }
    return 0;
}

复杂度

很复杂,原文好几页都在写这个,真的看不懂,大概是 \(O((n + q) log^2 n)\)\(O((n + q) log^3 n)\),一般情况下跑不到 \(3log\),看作 \(2log\) 即可。下述为作者较为感性的复杂度分析。

ktt 与普通线段树的最大区别,就是 \(intr\) 和 rebuild 这两个东西,他与普通线段树的复杂度差距也因为这两个东西而存在。

通过代码我们发现,每次 rebuild 都是在一个节点不会交替的时候终止,如果会交替,就会继续往子树递归,我们假设一个节点交替一次耗费一单位的时间,由于每个位置存储的值是一个一次函数,所以长度为 \(l\) 的段最多交替 \(l - 1\) 次,所以交替额外造成的时间复杂度是 \(O(n \log n)\) 的。

但可见,上述分析和论文的结果大相径庭 qwq,不过我们只要把他的复杂度当作 \(O(n \log^2 n)\) 即可。

题目们

常规问题模板(作者自己出的)

The Third Grace

EI 的第六分块

[Ynoi2015] 世上最幸福的女孩

posted @ 2024-05-27 20:11  cndark_moon  阅读(2792)  评论(5)    收藏  举报