Loading

线段树历史值学习笔记

(先单开出来,后面准备合并到线段树 trick 里)
(好像合并不了了)

历史和指的是线段树维护的序列 \(a\),我们再开一个序列 \(b\),每次修改 / 查询后进行 \(\forall b_i \leftarrow b_i + a_i\) 操作,\(b\) 称作 \(a\) 的历史和。

历史和一般搭配扫描线使用,多用于二维的问题模型。

做法

做法一:\(c_i = h_i - t \times a_i\)

最简单好写,也最不易推广的做法。

构造 \(c_i = h_i - t \times a_i\),每次 \(t\) 增长 \(1\)\(h_i \leftarrow h_i + a_i\),但是 \(c_i\)\(a_i\) 不修改的情况下不会变化。
对于 \(a_i \leftarrow a_i + v\)\(c_i \leftarrow c_i - tv\) 即可。

查询历史和,即 \(\sum c_i + t \sum a_i\),维护 \(a,c\) 即可。

做法二:矩阵

据说,矩阵乘法在这种问题中是万能的。

发现难点在于更新历史和是全局更新,但是朴素方法手动更新(维护 \(hsum, sum\)\(hsum \leftarrow hsum + sum\))复杂度不对。
这本质上是因为没有将加法和更新历史和的操作都拼合成一种标记(满足结合律,可以快速合并的标记)。

矩阵乘法及广义乘法满足了我们的需求。

我们可以维护 \(\begin{bmatrix} hsum, sum, len \end{bmatrix}\) 作为线段树节点的信息,并另外维护一个 \(3 \times 3\) 的矩阵标记。

对于区间加 \(v\),要实现 \(sum \leftarrow sum + v \times len\)
区间加矩阵:

\[\begin{bmatrix}1 & 0 & 0\\0 & 1 & 0\\0 & v & 1\\\end{bmatrix} \]

对于更新历史和,要实现 \(hsum \leftarrow hsum + sum\)
更新历史和矩阵:

\[\begin{bmatrix}1 & 0 & 0\\1 & 1 & 0\\0 & 0 & 1\end{bmatrix} \]

区间覆盖矩阵也能做,只需要把区间加矩阵的 \((2,2)\) 位置改成 \(0\),即 \(sum \leftarrow v \times len\) 即可。

弱点在于常数问题,当然可以手动拆开转移,只保留会改变的部分。

做法三:标记队列

部分参考 command_block 的博客

在线段树标记的下推机制中,某个点存有标记,表示整一棵子树在标记存在的时间内都未曾更新。
于是,问题的核心就在于分析单个节点上停留的标记的影响。

在非历史值问题中,我们只关注当下的标记,所以我们永远合并标记,便于存储。
但是在历史值问题中,我们需要考虑历史上存储过的标记的依次作用和当前的合并结果。

为了便于理解,我们暂时不考虑实现的可行性,我们假定每个节点维护了整个标记序列,以时间为顺序。
线段树上的每个节点维护一个类似“标记队列”,队列每一项是形如 \(+v\) 的加法操作或是更新历史和操作。

暂时,我们的线段树节点应该要维护一下信息:区间和 \(sum\),历史和 \(hsum\),加法标记 \(add\)

对于队列里面的操作,会对节点信息造成一下影响。

  • \(+ v\) 操作:\(sum \leftarrow sum + v \times len,add \leftarrow add + v\)
  • 更新历史和操作:\(hsum \leftarrow hsum + sum\)

那么我们考虑将父亲的队列合并到儿子的队列上是怎样的。
请注意,线段树上,父亲的标记合并到儿子的标记上时,儿子标记的时间是更靠前的,这对于不满足交换律的合并运算(如矩阵乘法)是至关重要的。

设将队列 \(2\) 合并到队列 \(1\) 上,队列 \(1\) 的时间靠前。
以下 \(sum, hsum\) 指的是节点存储的信息。\(add\) 是队列中 \(+v\) 操作的合并结果。

  • \(add_1 \leftarrow add_1 + add_2\),直接继承。

  • \(sum_1 \leftarrow sum_1 + len_1 \times add_2\)

  • 考虑 \(hsum_1\) 的变化:

    在加入队列 \(2\) 的若干操作后,原先的 \(sum_1\) 指的是队列 \(1\) 的合并结果,它会在队列 \(2\) 的每个“更新历史和”操作中用到,造成系数为 \(1\) 的贡献。
    \(upd\) 为一个队列中,“更新历史和”操作的次数,则这一部分贡献为 \(sum_1 \times upd_2\)

    还有一部分贡献来自于队列 \(2\) 中的 \(+ v\) 操作,它们会在队列 \(2\) 的每个“更新历史和”操作时作用在 \(hsum\) 上,因为是加法标记,造成系数为 \(len_1\) 的贡献。
    那么我们需要知道一个队列里,「每次“更新历史和”时的 \(add\)」 的和,记为 \(hadd\)
    这部分的贡献为 \(hadd_2 \times len_1\)

    综上,\(hsum_1 \leftarrow hsum_1 + sum_1 \times upd_2 + hadd_2 \times len_1\)

  • \(upd_1 \leftarrow upd_1 + upd_2\)

  • \(hadd_1 \leftarrow hadd_1 + add_1 \times upd_2 + hadd_2\)

    首先 \(hadd_1, hadd_2\) 造成贡献是显然的,都是在各自的队列时间范围内的贡献。

    还有队列 \(1\)\(+v\) 操作的合并结果 \(add_1\) 在队列 \(2\) 的时间范围内造成的贡献,是每一次队列 \(2\) 中“更新历史和”操作时体现的,故系数为 \(upd_2\)

于是,我们发现我们只需要刻画 \(add, upd, hadd\) 即可刻画出一整个队列,加上节点本身的 \(hsum, sum\),维护这些标记即可。

同时你会发现,标记队列不好做区间覆盖,这也是它的局限性。

例题

CF1834D

题解

P8868 [NOIP 2022] 比赛

本题使用标记队列法来解是更简单的。

请阅读并充分理解标记队列法,并充分理解 CF1824D 的扫描线做法,然后阅读此题解。

题意

两个序列 \(a,b\) 长度均为 \(n\)\(q\) 次询问,给出一个区间 \([l,r]\),求:

\[\sum \limits _ {l' = l} ^ r \sum \limits _ {r' = l'} ^ r \left( \max \limits _{i = l'} ^ {r'} a_i \right ) \times \left ( \max \limits _{i = l'} ^ {r'} b_i\right ) \]

人话是子区间的 \(a,b\) 极值的乘积的和。

问题分析

会了 CF1824D 之后,你应该很容易地知道这题应该使用扫描线,并且有能力预见到是扫描线配合线段树历史和的 trick。

离线询问,扫描线右端点,记录 \(f_i\) 表示对于当前右端点,左端点为 \(i\) 时的答案。

我们依旧是考虑右端点 \(j\) 移动时的改变,因为需要维护 \(a,b\) 的最大值,不难想到要维护单调栈(单调递减单调栈)
每次将两个单调栈(分别维护 \(a,b\))中 \(\lt a_j\) 的全部弹出,那么栈顶到 \(j\) 的位置全部更新 \(a\)\(b\)

于是数据结构要实现:

  • \(a\) 区间加,对 \(b\) 区间加(或者看成区间覆盖,但是标记队列不好做区间覆盖,单调栈的性质让我们可以改成区间加)。

  • 查询区间 \(a \times b\) 的历史和。

数据结构

知道了标记队列的做法后,这题就是标记队列进行简单更改后得到的。

线段树维护:

  • \(sab\),表示 \(a \times b\) 的区间和。

  • \(sa, sb\) 分别表示区间 \(a,b\) 的和。

  • \(hsab\),表示 \(sab\) 的历史和。

标记队列应当包括(但不限于):

  • \(adda, addb\),分别表示 \(a, b\) 的加法标记。

  • 更新历史和标记。

其影响:

  • \(a + v\) 操作,\(sab \leftarrow sab + v \times sb, sa \leftarrow sa + v \times len, adda \leftarrow adda + v\)

  • \(b + v\) 操作,\(sab \leftarrow sab + v \times sa, sb \leftarrow sb + v \times len, addb \leftarrow addb + v\)

  • 更新历史和操作:\(hsab \leftarrow hsab + sab\)

合并队列,依旧是队列 \(2\) 合并到队列 \(1\)

以下是定义:

    ull hsab, // sum a * b 的历史和
        sab, // sum a * b
        sa, // suma
        sb, // sumb
        len, // 区间长度

        hab, // a * b 每次操作的历史和
        ha, // a 每次操作历史和
        hb, // b 每次操作历史和
        upd, // 更新历史和操作次数
        adda, // a 加法标记
        addb; // b 加法标记

转移:完全就是板子题式子的稍微变种,只是注意分 \(a,b\) 讨论即可。

\[\begin{aligned} hab_1 &\leftarrow hab_1 + adda_1 \times addb_1 \times upd_2 + adda_1 \times hb_2 + addb_1 \times ha_2 + hab_2, \\[6pt] ha_1 &\leftarrow ha_1 + adda_1 \times upd_2 + ha_2, \\[6pt] hb_1 &\leftarrow hb_1 + addb_1 \times upd_2 + hb_2, \\[6pt] sab_1 &\leftarrow sab_1 + sa_1 \times addb_2 + sb_1 \times adda_2 + addb_2 \times adda_2 \times len_1, \\[6pt] sa_1 &\leftarrow sa_1 + adda_2 \times len_1, \\[6pt] sb_1 &\leftarrow sb_1 + addb_2 \times len_1, \\[6pt] upd_1 &\leftarrow upd_1 + upd_2, \\[6pt] adda_1 &\leftarrow adda_1 + adda_2, \\[6pt] addb_1 &\leftarrow addb_1 + addb_2. \end{aligned} \]

代码

const int N = 3e5 + 5;
int n, q;
ull a[N], b[N], ans[N];

struct node{
    // ull adda,addb,upd,ha,hb,l;
    ull hsab, // sum a * b 的历史和
        sab, // sum a * b
        sa, // suma
        sb, // sumb
        len, // 区间长度

        hab, // a * b 每次操作的历史和
        ha, // a 每次操作历史和
        hb, // b 每次操作历史和
        upd, // 更新历史和操作次数
        adda, // a 加法标记
        addb; // b 加法标记

    node(){
        hsab = sab = sa = sb = len = hab = ha = hb = upd = adda = addb = 0;
    }
} t[N << 2];
    node calc_add_node(bool type, ull v, int len){
        node res;
        if(type == 0) res.adda = v;
        else res.addb = v;
        res.len = len;
        return res;
    }
    node upd_h_node;

#define mid ((l + r) >> 1)
#define ls(x) (x << 1)
#define rs(x) ((x << 1) | 1)
void push_up(int x){
    t[x].sab = t[ls(x)].sab + t[rs(x)].sab;
    t[x].sa = t[ls(x)].sa + t[rs(x)].sa;
    t[x].sb = t[ls(x)].sb + t[rs(x)].sb;
    t[x].hsab = t[ls(x)].hsab + t[rs(x)].hsab;
}
void hard(int x, node v){
    t[x].hsab += t[x].sab * v.upd + t[x].sa * v.hb + t[x].sb * v.ha + v.hab * t[x].len;
	t[x].hab += t[x].adda * t[x].addb * v.upd + t[x].adda * v.hb + t[x].addb * v.ha + v.hab;
	t[x].ha += t[x].adda * v.upd + v.ha;
	t[x].hb += t[x].addb * v.upd + v.hb;
	t[x].sab += t[x].sa * v.addb + t[x].sb * v.adda + v.addb * v.adda * t[x].len;
	t[x].sa += v.adda * t[x].len;
	t[x].sb += v.addb * t[x].len;
	t[x].upd += v.upd;
	t[x].adda += v.adda;
	t[x].addb += v.addb;
}
void push_down(int x){
    hard(ls(x), t[x]);
    hard(rs(x), t[x]);
    t[x].hab = t[x].ha = t[x].hb = t[x].upd = t[x].adda = t[x].addb = 0;
}
void build(int x, int l, int r){
    t[x].len = r - l + 1;
    if(l == r) return;
    build(ls(x), l, mid);
    build(rs(x), mid + 1, r);
    push_up(x);
}
void modify(int x, int l, int r, int ql, int qr, ull v, bool type){ // type : 0 -> a , 1 -> b
    if(ql <= l && r <= qr){
        hard(x, calc_add_node(type, v, r - l + 1));
        return;
    }
    push_down(x);
    if(ql <= mid) modify(ls(x), l, mid, ql, qr, v, type);
    if(qr > mid) modify(rs(x), mid + 1, r, ql, qr, v, type);
    push_up(x);
}
ull query(int x, int l, int r, int ql, int qr){
    if(ql <= l && r <= qr){
        return t[x].hsab;
    }
    push_down(x);
    ull res = 0;
    if(ql <= mid) res += query(ls(x), l, mid, ql, qr);
    if(qr > mid) res += query(rs(x), mid + 1, r, ql, qr);
    return res;
}

struct Query{
    int l, qid;
};
vector<Query> qry[N];

int stk_a[N], stk_b[N], top_a, top_b;
void solve_test_case(){
    int cid = read();
    n = read();

    upd_h_node.upd = 1;

    rep(i, 1, n) a[i] = read();
    rep(i, 1, n) b[i] = read();
    q = read();
    rep(i, 1, q){
        int l = read(), r = read();
        qry[r].push_back({l, i});
    }

    build(1, 1, n);
    top_a = top_b = 1;
    // stk_a[++top_a] = 0, stk_b[++top_b] = 0;
    a[0] = b[0] = n + 1;

    rep(i, 1, n){
        while(a[stk_a[top_a]] < a[i]){
            modify(1, 1, n, stk_a[top_a - 1] + 1, stk_a[top_a], -a[stk_a[top_a]], 0);
            top_a--;
        }   
        modify(1, 1, n, stk_a[top_a] + 1, i, a[i], 0);
		stk_a[++top_a] = i;

        while(b[stk_b[top_b]] < b[i]){
            modify(1, 1, n, stk_b[top_b - 1] + 1, stk_b[top_b], -b[stk_b[top_b]], 1);
            top_b--;
        }
        modify(1, 1, n, stk_b[top_b] + 1, i, b[i], 1);
		stk_b[++top_b] = i;

        hard(1, upd_h_node);

        for(auto [l, qid] : qry[i]){
            ans[qid] = query(1, 1, n, l, i);
        }
    }
    rep(i, 1, q){
        write(ans[i]);
    }
}

SZMS OJ DS竞赛

题意

给你一个长为 \(n\) 的序列 \(a\),给定 \(d\)

对于一个序列 \(S\)

  • 若可以通过给 \(S\) 中加入一些数并排序的方式使得 \(S\) 成为一个公差为 \(d\) 的等差数列,那么 \(f(S) = \min 加入的数的个数\)

  • 否则,\(f(S) = 0\)

\(q\) 次询问,问一个区间 \([l,r]\) 的所有子区间对应序列的 \(f\) 值的和。

做法

分析一个序列可以变成等差数列的条件:

  • 所有数 \(\mod d\) 都相等。

  • 没有重复数字。

套路地扫描线右端点 \(i\),我们发现,能构成等差数列的区间的左端点一定是 \([1,i]\) 的一个后缀,那么我们可以双指针维护一个 \(L\) 表示可能成为等差数列的左端点的最小值。
\(L\) 的移动按照上文的条件进行即可,需要预处理每个数前面最后一个和它数值相同的位置。

线段树要维护的是 \(f\),我们每次把 \(L\) 往右移动时要将这个位置的 \(f\) 清空为 \(0\)

那么如何计算 \(f\)

对于一个序列 \(S\),已知可以变成等差数列,那么:

\[\begin{align*} f(S) &= (\frac {mx - mn} d + 1) - (r - l + 1) \\ &= \frac {mx} d - \frac {mn} d - r + l \end{align*} \]

因为已知 \(mx \mod d = mn \mod d\),所以直接将除法取整即可。

因为有最大值和最小值的要求,我们套路地维护单调栈,每次更改一个区间的 \(f\)

\([l,r]\) 的询问的答案就是 \(i=r\)\([l,r]\) 的历史和。

代码

#define int ll

const int N = 5e5 + 5;
int n, d, q;
int a[N];

struct node{
	ll hsum, sum, len;
	ll add, hadd, upd;
} t[N << 2];
node calc_add_node(int v){
	return node{0, 0, 0, v, 0, 0};
}
node upd_h_node;

#define mid ((l + r) >> 1)
#define ls(x) (x << 1)
#define rs(x) ((x << 1) | 1)

void hard(int x, node v){
	t[x].hsum += t[x].sum * v.upd + t[x].len * v.hadd;
	t[x].sum += t[x].len * v.add;
	t[x].hadd += v.hadd + t[x].add * v.upd;
	t[x].add += v.add;
	t[x].upd += v.upd;
}

void push_up(int x){
	t[x].hsum = t[ls(x)].hsum + t[rs(x)].hsum;
	t[x].sum = t[ls(x)].sum + t[rs(x)].sum;
}
void push_down(int x){
	hard(ls(x), t[x]);
	hard(rs(x), t[x]);
	t[x].add = t[x].hadd = t[x].upd = 0;
}
void build(int x, int l, int r){
	t[x].len = r - l + 1;
	if(l == r) return;
	build(ls(x), l, mid);
	build(rs(x), mid + 1, r);
	push_up(x);
}
void add(int x, int l, int r, int ql, int qr, int v){
	if(ql > qr) return;
	if(ql <= l && r <= qr){
		hard(x, calc_add_node(v));
		return;
	}
	push_down(x);
	if(ql <= mid) add(ls(x), l, mid, ql, qr, v);
	if(qr > mid) add(rs(x), mid + 1, r, ql, qr, v);
	push_up(x);
}
void clear(int x, int l, int r, int p){
	if(l == r){
		t[x].sum = 0;
		return;
	}
	push_down(x);
	if(p <= mid) clear(ls(x), l, mid, p);
	else clear(rs(x), mid + 1, r, p);
	push_up(x);
}
ll query(int x, int l, int r, int ql, int qr){
	if(ql > qr) return 0;
	if(ql <= l && r <= qr){
		return t[x].hsum;
	}
	push_down(x);
	ll res = 0;
	if(ql <= mid) res += query(ls(x), l, mid, ql, qr);
	if(qr > mid) res += query(rs(x), mid + 1, r, ql, qr);
	return res;
}

ll ans[N];
struct Query{
	int l, qid;
};
vector<Query> qry[N];
int pre[N];
map<int, int> pos;
pair<int, int> stk_mx[N]; int top_mx;
pair<int, int> stk_mn[N]; int top_mn;
int L;
void solve_test_case(){
	n = read(), d = read(), q = read();
	rep(i, 1, n){
		a[i] = read();
		pre[i] = pos[a[i]];
		pos[a[i]] = i;
	}
	if(n == 0){
		while(q--){
			puts("0");
		}
		return;
	}
	rep(i, 1, q){
		int l = read(), r = read();
		qry[r].push_back({l, i});
	}
	
	upd_h_node.upd = 1;
	
	build(1, 1, n);
	L = 1;
	
	stk_mx[++top_mx] = {0, 1e7 + 1};
	stk_mn[++top_mn] = {0, 0};
	
	rep(i, 1, n){
		while(L <= i && a[L] % d != a[i] % d){
			clear(1, 1, n, L);
			L++;
		}
		while(L <= i && L <= pre[i]){
			clear(1, 1, n, L);
			L++;
		}
		while(top_mx && stk_mx[top_mx].first >= L && stk_mx[top_mx].second < a[i]){
			add(1, 1, n, max(stk_mx[top_mx - 1].first + 1, L), stk_mx[top_mx].first, (int)(a[i] / d) - (int)(stk_mx[top_mx].second / d));
			top_mx--;
		}
		stk_mx[++top_mx] = {i, a[i]};
		
		while(top_mn && stk_mn[top_mn].first >= L && stk_mn[top_mn].second > a[i]){
			add(1, 1, n, max(stk_mn[top_mn - 1].first + 1, L), stk_mn[top_mn].first, (int)(stk_mn[top_mn].second / d) - (int)(a[i] / d));			
			top_mn--;
		}
		stk_mn[++top_mn] = {i, a[i]};
		
		add(1, 1, n, L, i - 1, -1);
		hard(1, upd_h_node);
		for(Query cur : qry[i]){
			int l = cur.l, qid = cur.qid;
			ans[qid] = query(1, 1, n, l, i);
		}
	}
	rep(i, 1, q){
		write(ans[i]);
	}
}
posted @ 2025-10-18 22:50  lajishift  阅读(36)  评论(5)    收藏  举报