线段树维护矩阵
对于一些只有单点修改且维护的信息具有递推性的题目,由于运算具有结合律,可以将两区间合并写成矩阵乘法的形式,省去一些麻烦的讨论。
前置知识:广义矩阵乘法
对于一个 \(n\times m\) 的矩阵 \(A\) 和一个 \(m\times t\) 的矩阵 \(B\),定义广义矩阵乘法:
若 \(\otimes\) 对 \(\oplus\) 满足分配律,即 \((a\oplus b)\otimes c=(a\otimes c) \oplus (b\otimes c)\),则新定义的广义矩阵乘法具有结合律。
常见的例子有 \(\otimes=\times,\oplus=+\) 和 \(\otimes =+,\oplus = \min/\max\) 等。
广义矩阵乘法的单位矩阵主对角线上的元素是 \(\otimes\) 的单位元,其他位置是 \(\oplus\) 的单位元。
序列问题
题意:
有一个括号串 \(S\),一开始 \(S\) 中只包含一对括号(即初始的 \(S\) 为
()),接下来有 \(n\) 个操作,操作分为三种:
在当前 \(S\) 的末尾加一对括号(即 \(S\) 变为
S());在当前 \(S\) 的最外面加一对括号(即 \(S\) 变为
(S));取消第 \(x\) 个操作,即去除第 \(x\) 个操作造成过的一切影响(例如,如果第 \(x\) 个操作也是取消操作,且取消了第 \(y\) 个操作,那么当前操作的实质就是恢复了第 \(y\) 个操作的作用效果)。
每次操作后,你需要输出 \(S\) 的能够括号匹配的非空子串(子串要求连续)个数。
一个括号串能够括号匹配,当且仅当其左右括号数量相等,且任意一个前缀中左括号数量不少于右括号数量。
\(1\le n\le 2\times 10^5\)。
如果没有撤销,只需要维护合法子串数 \(ans\) 和 合法后缀数 \(suf\),则操作 1 就是 \(ans\leftarrow ans+suf+1,suf\leftarrow suf+1\),操作 2 就是 \(ans\leftarrow ans+1,suf\leftarrow1\)。
现在有了撤销,考虑用矩阵描述操作。
操作 1:
操作 2:
这里的矩阵乘法就是传统的矩阵乘法。
\(S\) 初始为 (),所以初始答案矩阵显然是 \(\begin{bmatrix}1&1&1\end{bmatrix}\),答案就是将初始矩阵按顺序乘上若干个修改矩阵后,第一个元素的值。
现在考虑撤销:撤销一个修改就是把这个修改矩阵改成单位矩阵;撤销对一个修改的撤销就是把这个矩阵再改回对应的修改矩阵,以此类推……
我们发现有了矩阵这个神奇工具后撤销就变得平凡了,由于矩阵乘法具有结合律,所以可以用线段树维护,时间复杂度 \(O(n\log n)\)。
#include<bits/stdc++.h>
#define endl '\n'
#define rep(i, s, e) for(int i = s, i##E = e; i <= i##E; ++i)
#define per(i, s, e) for(int i = s, i##E = e; i >= i##E; --i)
#define F first
#define S second
#define int ll
#define gmin(x, y) (x = min(x, y))
#define gmax(x, y) (x = max(x, y))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double f128;
typedef pair<int, int> pii;
constexpr int N = 2e5 + 5;
struct matrix {
int a[3][3];
matrix() { memset(a, 0, sizeof a); }
matrix(vector<int> v) {
rep(i, 0, 2) rep(j, 0, 2) a[i][j] = v[i * 3 + j];
}
int *operator[](int x) { return a[x]; }
matrix operator*(matrix &b) const {
matrix c;
rep(i, 0, 2) rep(k, 0, 2) rep(j, 0, 2)
c[i][j] += a[i][k] * b[k][j];
return c;
}
};
const matrix uni({1, 0, 0, 0, 1, 0, 0, 0, 1}),
op1({1, 0, 0, 1, 1, 0, 1, 1, 1}),
op2({1, 0, 0, 0, 0, 0, 1, 1, 1});
matrix nd[N * 4];
#define ls (p << 1)
#define rs (p << 1 | 1)
void build(int p, int l, int r) {
if(l == r) return nd[p] = uni, void();
int mid = (l + r) / 2;
build(ls, l, mid);
build(rs, mid + 1, r);
nd[p] = nd[ls] * nd[rs];
}
void modify(int p, int l, int r, int loc, const matrix &op) {
if(l == r) return nd[p] = op, void();
int mid = (l + r) / 2;
if(loc <= mid) modify(ls, l, mid, loc, op);
else modify(rs, mid + 1, r, loc, op);
nd[p] = nd[ls] * nd[rs];
}
int n, p[N], op[N];
signed main() {
#ifdef ONLINE_JUDGE
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
#endif
cin >> n;
build(1, 1, n);
rep(i, 1, n) {
cin >> op[i];
if(op[i] == 1) modify(1, 1, n, i, op1), p[i] = i;
else if(op[i] == 2) modify(1, 1, n, i, op2), p[i] = i;
else {
int x; cin >> x;
if(p[x] > 0) modify(1, 1, n, p[x], uni), p[i] = -p[x];
else if(op[-p[x]] == 1) modify(1, 1, n, -p[x], op1), p[i] = -p[x];
else modify(1, 1, n, -p[x], op2), p[i] = -p[x];
}
cout << nd[1][0][0] + nd[1][1][0] + nd[1][2][0] << endl;
}
return 0;
}
树上问题
例题:Luogu P4719 【模板】"动态 DP"&动态树分治
题意:
给定一棵 \(n\) 个点的树,点带点权。
有 \(m\) 次操作,每次操作给定 \(x,y\),表示修改点 \(x\) 的权值为 \(y\)。
你需要在每次操作之后求出这棵树的最大权独立集的权值大小。
\(1\le n,m\le10^5\),任意时刻点权 \(\in[-100,100]\)。
首先还是考虑没有修改的情况,也就是没有上司的舞会。取 \(1\) 号点为根,设 \(f(u,0)\) 表示只考虑 \(u\) 的子树时,不选择节点 \(u\) 的最大权值,\(f(u,1)\) 表示选择 \(u\) 的最大权值,则有:(其中 \(son(u)\) 表示 \(u\) 的所有儿子的集合,\(a_u\) 是 \(u\) 的权值)
答案就是 \(\max(f(1,0), f(1,1))\)。
现在带上修改,如果每次修改完都暴力重新 DP,时间复杂度 \(O(nq)\),TLE します。
我们发现,每次修改只会更改一条路径,这给了我们用树剖的自信。考虑结合上一题的思路,用树链剖分把树上问题转化为序列问题,然后用矩阵来描述转移,套上线段树求解。
然而我们尴尬地发现这个转移带一个 \(\sum\),这导致我们无法用一个较小的矩阵直接描述转移。考虑利用树链剖分分出的“轻重儿子”概念,定义 \(g(u,0)\) 和 \(g(u,1)\) 分别表示只考虑节点 \(\boldsymbol u\) 和它的所有轻子树时,不选与选节点 \(u\) 的最大权值。那么就有:(这里的 \(v\) 指 \(u\) 的重儿子)
于是我们消掉了 \(f\) 的转移上的 \(\sum\),但是这仍然不能用传统矩阵乘法,于是我们定义一种广义的矩阵乘法,令 \(\oplus=\max,\otimes=+\),即 \(C=A\times B\) 当且仅当 \(C_{i,j}=\max\limits_{k=1}^{m}\{A_{i,k}+B_{k,j}\}\)(\(m\) 表示矩阵大小)。注意到这个乘法满足结合律是因为加法对 \(\max\) 满足分配率,即 \(\max(a,b)+c=\max(a+c,b+c)\),我们可以据此改写转移方程:
于是我们可以构造转移矩阵:(这里用的是我们新定义的广义矩阵乘法)(状态矩阵的横竖会影响实现,事实上横竖都能做)
我们在每个节点维护转移矩阵。接下来考虑修改,把节点 \(u\) 的权值改为 \(w\)。
在 \(u\) 所在的这条重链上,因为其他节点的轻儿子都不包含 \(u\),所以被更改的只有 \(g(u,1)\)。更改完这条链后,如果已经到达根(\(top(u)=1\))就退出,否则根据 \(f(top(u))\) 更改 \(g(fa(top(u)))\),向上递归。
如果要快速获得一个点 \(u\) 的 DP 值,只需要找到它所在的重链的底部节点 \(x\),查询 \(u-x\) 的路径上的总转移矩阵,与 \(v\) 点的状态矩阵相乘即可。因为 \(x\) 一定是叶子,所以它的的状态一定是 \(\begin{bmatrix}0\\a_x\end{bmatrix}\)。
由于叶节点的转移没有意义,我们可以直接让叶节点的转移是 \(\begin{bmatrix}0&\rm{void}\\a_x&\rm{void}\end{bmatrix}\),其中 \(\rm{void}\) 没有意义,可以是任何值,这样就不用手动乘 \(\begin{bmatrix}0\\a_x\end{bmatrix}\) 了。
时间复杂度 \(O(n\log^2n)\),用一个叫全局平衡二叉树的东西可以优化到 \(O(n\log n)\),但是我不会。
实现时要注意,更改时不能用全局定义的 \(f\) 数组,因为那个 \(f\) 不是实时更新的。
#include<bits/stdc++.h>
#define endl '\n'
#define rep(i, s, e) for(int i = s, i##E = e; i <= i##E; ++i)
#define per(i, s, e) for(int i = s, i##E = e; i >= i##E; --i)
#define F first
#define S second
#define gmin(x, y) (x = min(x, y))
#define gmax(x, y) (x = max(x, y))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double f128;
typedef pair<int, int> pii;
constexpr int N = 1e5 + 5, inf = 0x3f3f3f3f;
struct matrix {
int a[2][2];
int* operator[](int x) { return a[x]; }
matrix operator*(matrix b) const {
matrix c;
c[0][0] = max(a[0][0] + b[0][0], a[0][1] + b[1][0]);
c[0][1] = max(a[0][0] + b[0][1], a[0][1] + b[1][1]);
c[1][0] = max(a[1][0] + b[0][0], a[1][1] + b[1][0]);
c[1][1] = max(a[1][0] + b[0][1], a[1][1] + b[1][1]);
return c;
}
};
int n, m, a[N], f[N][2], g[N][2];
vector<int> to[N];
int fa[N], dep[N], sz[N], hs[N], tp[N], ed[N], dfn[N], rnk[N];
void dfs1(int u, int f) {
fa[u] = f, dep[u] = dep[f] + 1, sz[u] = 1;
for(int v : to[u]) if(v != f) {
dfs1(v, u);
sz[u] += sz[v];
if(sz[v] > sz[hs[u]]) hs[u] = v;
}
}
int dfs2(int u, int fa, int top) {
static int ind;
tp[u] = top, dfn[u] = ++ind, rnk[ind] = u;
g[u][1] = a[u];
if(hs[u]) ed[u] = dfs2(hs[u], u, top);
else return f[u][1] = a[u], ed[u] = dfn[u];
for(int v : to[u]) if(v != fa && v != hs[u]) {
dfs2(v, u, v);
g[u][0] += max(f[v][0], f[v][1]);
g[u][1] += f[v][0];
}
f[u][0] = g[u][0] + max(f[hs[u]][0], f[hs[u]][1]);
f[u][1] = g[u][1] + f[hs[u]][0];
return ed[u];
}
matrix nd[N * 4];
#define ls (p << 1)
#define rs (p << 1 | 1)
void assign(int p, int u) {
if(!hs[u]) {
nd[p][1][0] = a[u];
}
else {
nd[p][0][0] = nd[p][0][1] = g[u][0];
nd[p][1][0] = g[u][1];
nd[p][1][1] = -inf;
}
}
void build(int p, int l, int r) {
if(l == r) return assign(p, rnk[l]);
int mid = (l + r) / 2;
build(ls, l, mid);
build(rs, mid + 1, r);
nd[p] = nd[ls] * nd[rs];
}
void update(int p, int l, int r, int loc) {
if(l == r) return assign(p, rnk[l]);
int mid = (l + r) / 2;
if(loc <= mid) update(ls, l, mid, loc);
else update(rs, mid + 1, r, loc);
nd[p] = nd[ls] * nd[rs];
}
matrix query(int p, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) return nd[p];
int mid = (l + r) / 2;
if(qr <= mid) return query(ls, l, mid, ql, qr);
if(ql > mid) return query(rs, mid + 1, r, ql, qr);
return query(ls, l, mid, ql, qr) * query(rs, mid + 1, r, ql, qr);
}
void modify(int u, int w) {
g[u][1] += w - a[u];
a[u] = w;
matrix o, p;
while(1) {
if(tp[u] != 1) o = query(1, 1, n, dfn[tp[u]], ed[u]);
update(1, 1, n, dfn[u]);
if(tp[u] == 1) break;
p = query(1, 1, n, dfn[tp[u]], ed[u]);
u = fa[tp[u]];
g[u][0] += max(p[0][0], p[1][0]) - max(o[0][0], o[1][0]);
g[u][1] += p[0][0] - o[0][0];
}
}
signed main() {
#ifdef ONLINE_JUDGE
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
#endif
cin >> n >> m;
rep(i, 1, n) cin >> a[i];
rep(i, 1, n - 1) {
int u, v; cin >> u >> v;
to[u].push_back(v);
to[v].push_back(u);
}
dfs1(1, 0);
dfs2(1, 0, 1);
build(1, 1, n);
while(m--) {
int u, w; cin >> u >> w;
modify(u, w);
auto o = query(1, 1, n, 1, ed[1]);
cout << max(o[0][0], o[1][0]) << endl;
}
return 0;
}
练习题:Luogu P8820 [CSP-S 2022] 数据传输
形式化题意:
给定一棵 \(n\) 个节点的树,和一个常数 \(k\),树上每个节点 \(i\) 都有权值 \(v_i\)。
定义 \(\operatorname{dis}(i,j)\) 表示树上 \(i,j\) 两点间简单路径的边数。
有 \(Q\) 次询问,每次询问给定两个节点 \(s,t(s\neq t)\),你需要找出一个长为 \(m\) 的序列 \(c\),满足 \(\forall i\in[1,m],c_i\in[1,n]\) 且 \(c_1=s,c_m=t\), \(\forall i\in[1,m-1],\operatorname{dis}(c_i,c_{i+1})\le k\),使得 \(\sum_{i=1}^mv_{c_i}\) 最小,输出这个最小值。
\(1\le n,Q\le2\times10^5,1\le v_i\le10^9,1\le k\le 3\)。
其实就是从 \(s\) 开始,每次最多走 \(k\) 步,走到 \(t\),最小化经过的的点权和。
注意到 \(k\) 很小,设 \(\operatorname{path}(i,j)\) 表示 \(i,j\) 间简单路径上的点的集合,考虑分类讨论 \(k\):
- \(k=1\) 时,就是简单路径点权和,容易实现
- \(k=2\) 时,如果存在 \(i\in[1,m],c_i\not\in\operatorname{path}(s,t)\),则必然存在\(j\in[i+1,m], c_j\in\operatorname{path}(s,t)\),因为原图是树且 \(k=2\),所以 \(\operatorname{dis}(c_i,c_j)\le 2\)(这里可以画个图理解一下)所以一定可以直接从 \(c_{i-1}\) 走到 \(c_j\),又因为 \(v>0\),所以不走 \(c_{i}\dots c_{j-1}\) 一定比走这些点要优,所以 \(\forall i\in[1,m],c_i\in\operatorname{path}(s,t)\),于是可以简单 DP \(f(i)=\max(f(i-1),f(i-2))+v_{c_i}\),令广义矩阵乘法 \(\oplus=\min,\otimes=+\),容易构造转移矩阵
- \(k=3\) 时,沿用 \(k=2\) 时的证明不难发现,如果存在 \(c_i\not\in\operatorname{path}(s,t)\),则必然有 \(\min_{j\in\operatorname{path}(s,t)}\operatorname{dis}(c_i,j)=1\),这种情况也是容易考虑的,设 \(f(i,j)\) 表示到达与 \(c_i\) 的距离为 \(j\) 的点所需要的最小代价,设 \(mn_u\) 表示所有与 \(u\) 相邻的点中最小的权值,可以构造转移矩阵:
然后树剖线段树维护就可以了,不过由于这个东西是有顺序的,所以线段树节点上需要同时维护左乘右和右乘左两个矩阵。
时间复杂度 \(O(n\log^2n)\),实现时需要注意树剖跳重链时矩阵乘法的顺序。左乘右和右乘左两种询问分开写可以减小常数。
#include<bits/stdc++.h>
#define endl '\n'
#define rep(i, s, e) for(int i = s, i##E = e; i <= i##E; ++i)
#define per(i, s, e) for(int i = s, i##E = e; i >= i##E; --i)
#define F first
#define S second
#define gmin(x, y) (x = min(x, y))
#define gmax(x, y) (x = max(x, y))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double f128;
typedef pair<int, int> pii;
constexpr int N = 2e5 + 5;
int n, q, k, a[N], mn[N];
vector<int> to[N];
struct matrix {
ll a[3][3];
matrix() { memset(a, 0x3f, sizeof a); }
ll* operator[](int x) { return a[x]; }
matrix operator*(matrix b) const {
matrix c;
rep(i, 0, 2) rep(j, 0, 2)
c[i][j] = min({a[i][0] + b[0][j], a[i][1] + b[1][j], a[i][2] + b[2][j]});
return c;
}
};
int fa[N], dep[N], sz[N], hs[N], top[N], dfn[N], rnk[N];
void dfs1(int u, int f) {
dep[u] = dep[f] + 1;
fa[u] = f;
sz[u] = 1;
for(int v : to[u]) if(v != f) {
dfs1(v, u);
sz[u] += sz[v];
if(sz[v] > sz[hs[u]]) hs[u] = v;
}
}
void dfs2(int u, int f, int tp) {
static int ind;
top[u] = tp;
dfn[u] = ++ind;
rnk[ind] = u;
if(hs[u]) dfs2(hs[u], u, tp);
for(int v : to[u])
if(v != f && v != hs[u])
dfs2(v, u, v);
}
pair<matrix, matrix> nd[N * 4];
#define ls (p << 1)
#define rs (p << 1 | 1)
void assign(int p, int u) {
nd[p].F[1][0] = nd[p].F[2][1] = nd[p].S[1][0] = nd[p].S[2][1] = 0;
nd[p].F[0][0] = nd[p].S[0][0] = a[u];
if(k >= 2) nd[p].F[0][1] = nd[p].S[0][1] = a[u];
if(k == 3)
nd[p].F[0][2] = nd[p].S[0][2] = a[u],
nd[p].F[1][1] = nd[p].S[1][1] = mn[u];
}
void build(int p, int l, int r) {
if(l == r) return assign(p, rnk[l]);
int mid = (l + r) / 2;
build(ls, l, mid);
build(rs, mid + 1, r);
nd[p].F = nd[ls].F * nd[rs].F;
nd[p].S = nd[rs].S * nd[ls].S;
}
matrix qlr(int p, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) return nd[p].F;
int mid = (l + r) / 2;
if(qr <= mid) return qlr(ls, l, mid, ql, qr);
if(ql > mid) return qlr(rs, mid + 1, r, ql, qr);
return qlr(ls, l, mid, ql, qr) * qlr(rs, mid + 1, r, ql, qr);
}
matrix qrl(int p, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) return nd[p].S;
int mid = (l + r) / 2;
if(qr <= mid) return qrl(ls, l, mid, ql, qr);
if(ql > mid) return qrl(rs, mid + 1, r, ql, qr);
return qrl(rs, mid + 1, r, ql, qr) * qrl(ls, l, mid, ql, qr);
}
ll query(int u, int v) {
matrix um, vm;
rep(i, 0, 2) um[i][i] = vm[i][i] = 0;
while(top[u] != top[v]) {
if(dep[top[u]] >= dep[top[v]]) {
um = qlr(1, 1, n, dfn[top[u]], dfn[u]) * um;
u = fa[top[u]];
}
else {
vm = vm * qrl(1, 1, n, dfn[top[v]], dfn[v]);
v = fa[top[v]];
}
}
matrix o;
if(dep[u] >= dep[v]) o = vm * qlr(1, 1, n, dfn[v], dfn[u]) * um;
else o = vm * qrl(1, 1, n, dfn[u], dfn[v]) * um;
return o[0][k - 1];
}
signed main() {
#ifdef ONLINE_JUDGE
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
#endif
cin >> n >> q >> k;
memset(mn, 0x3f, sizeof mn);
rep(i, 1, n) cin >> a[i];
rep(i, 1, n - 1) {
int u, v; cin >> u >> v;
to[u].push_back(v);
to[v].push_back(u);
gmin(mn[u], a[v]);
gmin(mn[v], a[u]);
}
dfs1(1, 0);
dfs2(1, 0, 1);
build(1, 1, n);
while(q--) {
int s, t; cin >> s >> t;
cout << query(s, t) << endl;
}
return 0;
}

浙公网安备 33010602011771号