线段树区间加、区间乘、区间推平、区间取相反数的通用处理办法
首先声明:“通用”并不是万能,只是能维护这些操作下的大多数常见的区间信息。
将数列中的每个元素视为一个一次函数 \(f_i(x)=k_ix+b_i\)。假设数列为 \(a\),则初始化 \(f_i(x)=0x+a_i\)。
区间加、区间乘操作可以视为将区间每个一次函数复合一个一次函数 \(g_j(x)=k_jx+b_j\),其中区间加 \(t_j\) 为 \(g_j(x)=1x+t_j\),区间乘 \(t_j\) 为 \(g_j(x)=t_jx+0\)。
对 \(g_j(f_i(x))\) 进行变形:
\[\begin{aligned}
g_j(f_i(x))&=k_j(k_ix+b_i)+b_j\\
&=k_jk_ix+k_jb_i+b_j
\end{aligned}
\]
得到一个新的一次函数,这个一次函数的斜率和截距可以通过两个函数的斜率和截距很方便地求出。同时根据以上公式,容易知道(在不少题解中都讲不清楚的)加法和乘法的优先级顺序。
区间推平为 \(t_j\) 可以视为复合一次函数 \(g_j(x)=0x+t_j\),区间取相反数可以视为复合一次函数 \(g_j(x)=-x+0\)。
综上,区间加、区间乘、区间推平、区间取相反数等类似操作都可以视为线段树维护一次函数,并进行一次函数复合操作。
更一般地,线段树可以视为维护元素半群和操作半群的数据结构,因此可以进行较为通用的封装。
代码
// Problem: P1253 [yLOI2018] 扶苏的问题
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P1253
// Memory Limit: 512 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
//By: OIer rui_er
#include <bits/stdc++.h>
#define rep(x,y,z) for(ll x=(y);x<=(z);x++)
#define per(x,y,z) for(ll x=(y);x>=(z);x--)
#define debug(format...) fprintf(stderr, format)
#define fileIO(s) do{freopen(s".in","r",stdin);freopen(s".out","w",stdout);}while(false)
#define likely(exp) __builtin_expect(!!(exp), 1)
#define unlikely(exp) __builtin_expect(!!(exp), 0)
using namespace std;
typedef long long ll;
const ll N = 1e6+5;
ll n, m, a[N];
template<typename T> void chkmin(T& x, T y) {if(x > y) x = y;}
template<typename T> void chkmax(T& x, T y) {if(x < y) x = y;}
struct Node {
ll mx, k, b; // lazy tag: kx + b
};
struct SegTree {
Node t[N<<2];
#define lc(u) (u<<1)
#define rc(u) (u<<1|1)
void pushup(ll u) {
t[u].mx = max(t[lc(u)].mx, t[rc(u)].mx);
}
void pushdown(ll u, ll l, ll r) {
ll mid = (l + r) >> 1;
t[lc(u)].mx = t[u].k * t[lc(u)].mx + t[u].b;
t[lc(u)].k = t[u].k * t[lc(u)].k;
t[lc(u)].b = t[u].k * t[lc(u)].b + t[u].b;
t[rc(u)].mx = t[u].k * t[rc(u)].mx + t[u].b;
t[rc(u)].k = t[u].k * t[rc(u)].k;
t[rc(u)].b = t[u].k * t[rc(u)].b + t[u].b;
t[u].k = 1; t[u].b = 0;
}
void build(ll* a, ll u, ll l, ll r) {
t[u].k = 1; t[u].b = 0;
if(l == r) {
t[u].mx = a[l];
return;
}
ll mid = (l + r) >> 1;
build(a, lc(u), l, mid);
build(a, rc(u), mid+1, r);
pushup(u);
}
void modify(ll u, ll l, ll r, ll ql, ll qr, ll k, ll b) {
if(ql <= l && r <= qr) {
t[u].mx = k * t[u].mx + b;
t[u].k = k * t[u].k;
t[u].b = k * t[u].b + b;
return;
}
pushdown(u, l, r);
ll mid = (l + r) >> 1;
if(ql <= mid) modify(lc(u), l, mid, ql, qr, k, b);
if(qr > mid) modify(rc(u), mid+1, r, ql, qr, k, b);
pushup(u);
}
ll query(ll u, ll l, ll r, ll ql, ll qr) {
if(ql <= l && r <= qr) return t[u].mx;
pushdown(u, l, r);
ll mid = (l + r) >> 1, ans = LLONG_MIN;
if(ql <= mid) chkmax(ans, query(lc(u), l, mid, ql, qr));
if(qr > mid) chkmax(ans, query(rc(u), mid+1, r, ql, qr));
pushup(u);
return ans;
}
#undef lc
#undef rc
}sgt;
int main() {
scanf("%lld%lld", &n, &m);
rep(i, 1, n) scanf("%lld", &a[i]);
sgt.build(a, 1, 1, n);
while(m--) {
ll op, x, y, z;
scanf("%lld", &op);
if(op == 1) {
scanf("%lld%lld%lld", &x, &y, &z);
sgt.modify(1, 1, n, x, y, 0, z);
}
else if(op == 2) {
scanf("%lld%lld%lld", &x, &y, &z);
sgt.modify(1, 1, n, x, y, 1, z);
}
else {
scanf("%lld%lld", &x, &y);
printf("%lld\n", sgt.query(1, 1, n, x, y));
}
}
return 0;
}