动态 DP
动态 DP
对于某些序列上或树上的 DP 问题,要求修改信息并重新求答案。此时,若每次修改后都重新 DP 比较费时,于是引入动态 DP。
广义矩阵乘法
定义矩阵乘法 \(A \times B = C\) 为:
广义矩阵乘法则考虑修改矩阵元素之间的运算方式,将加法与乘法改为其他运算,并且仍向上面一样定义矩阵乘法,那么只需要满足:
- 乘法对加法满足分配律。
- 加法本身满足交换律和结合律。
如可以将加法改为 \(\min\) / \(\max\) ,乘法改为求和。
序列动态 DP
GSS3 - Can you answer these queries III
单点修改、询问区间最大子段和。
\(n, m \le 5 \times 10^4\)
先考虑没有修改的情况,设 \(f_i\) 表示以 \(a_i\) 结尾的最大子段和,\(g_i\) 表示 \([1, i]\) 的最大子段和,显然有转移:
用 \((\max, +)\) 矩阵乘法将转移式写成以下形式:
那么单点修改只要改一个位置的矩阵 \(A\) ,区间查询只要查这段区间的矩阵乘积即可,时间复杂度 \(O(m \log n)\) 。
#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 5e4 + 7;
struct Matrix {
int a[3][3];
inline Matrix() {
memset(a, -inf, sizeof(a));
}
inline void prework(int x) {
a[0][0] = x, a[0][1] = x, a[0][2] = -inf;
a[1][0] = -inf, a[1][1] = 0, a[1][2] = -inf;
a[2][0] = x, a[2][1] = x, a[2][2] = 0;
}
inline Matrix operator * (const Matrix &rhs) {
Matrix res;
for (int i = 0; i < 3; ++i)
for (int j = 0; j < 3; ++j)
for (int k = 0; k < 3; ++k)
res.a[i][k] = max(res.a[i][k], a[i][j] + rhs.a[j][k]);
return res;
}
};
int a[N];
int n, m;
namespace SMT {
Matrix mt[N << 2];
inline int ls(int x) {
return x << 1;
}
inline int rs(int x) {
return x << 1 | 1;
}
void build(int x, int l, int r) {
if (l == r) {
mt[x].prework(a[l]);
return;
}
int mid = (l + r) >> 1;
build(ls(x), l, mid), build(rs(x), mid + 1, r);
mt[x] = mt[ls(x)] * mt[rs(x)];
}
void update(int x, int nl, int nr, int p, int k) {
if (nl == nr) {
mt[x].prework(k);
return;
}
int mid = (nl + nr) >> 1;
if (p <= mid)
update(ls(x), nl, mid, p, k);
else
update(rs(x), mid + 1, nr, p, k);
mt[x] = mt[ls(x)] * mt[rs(x)];
}
Matrix query(int x, int nl, int nr, int l, int r) {
if (l <= nl && nr <= r)
return mt[x];
int mid = (nl + nr) >> 1;
if (r <= mid)
return query(ls(x), nl, mid, l, r);
else if (l > mid)
return query(rs(x), mid + 1, nr, l, r);
else
return query(ls(x), nl, mid, l, r) * query(rs(x), mid + 1, nr, l, r);
}
} // namespace SMT
signed main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
scanf("%d", a + i);
SMT::build(1, 1, n);
scanf("%d", &m);
while (m--) {
int op, x, y;
scanf("%d%d%d", &op, &x, &y);
if (op) {
Matrix ans = SMT::query(1, 1, n, x, y);
printf("%d\n", max(ans.a[0][1], ans.a[2][1]));
} else
SMT::update(1, 1, n, x, y);
}
return 0;
}
P5069 [Ynoi Easy Round 2015] 纵使日薄西山
对于序列 \(a_{1 \sim n}\) ,定义一次操作为:选出序列中最大值的出现位置 \(x\) (有多个取最靠前的),将 \(a_{x - 1}, a_x, a_{x + 1}\) 的值减 \(1\) 。
\(q\) 次操作,每次单点修改一个位置的值,每次修改后求使所有数 \(\le 0\) 的操作次数。
\(n, q \le 10^5\)
考虑将会被操作的位置记为 \(1\) ,反之记为 \(0\) ,并且这个 01 序列具有如下性质:
- 任意两个 \(1\) 不相邻。
- 不存在长度 \(> 2\) 的 \(0\) 的极长连续段。
- 对于每个 \(0\) 的位置 \(x\) ,存在一个相邻的 \(1\) 在 \(a\) 中的值大于该位置在 \(a\) 中的值。
可以发现 01 序列是唯一的,答案即为所有 \(1\) 位置的 \(a\) 之和。
设 \(f_{i, 0/1/2}\) 表示考虑了前 \(i\) 个位置时所有 \(1\) 位置的 \(a\) 之和,第三位分三种情况:
- \(i\) 处填 \(1\) 。
- \(i\) 处填 \(0\) 且 \(i - 1\) 处填 \(1\) ,需要保证 \(a_{i - 1} \ge a_i\) 。
- \(i\) 处填 \(0\) ,但 \(i - 1\) 处也填 \(0\) ,或 \(i - 1\) 处填 \(1\) 且 \(a_{i - 1} < a_i\) 。
- 此时需要下一个位置填 \(1\) 且下一个位置在 \(a\) 中的值大于该位置在 \(a\) 中的值。
转移分为以下几种:
- \(i\) 处填 \(0\) :
- \(i - 1\) 处填 \(0\) :\(f_{i - 1, 1} \to f_{i, 2}\) 。
- \(i - 1\) 处填 \(1\) :\(f_{i - 1, 0} \to f_{i, 1 + [a_i > a_{i - 1}]}\) 。
- \(i\) 处填 \(1\) (\(i - 1\) 处填 \(0\)):
- \(f_{i - 1, 1} + a_i \to f_{i, 0}\) 。
- \(f_{i - 1, 2} + a_i \to f_{i, 0}\) ,此时需要满足 \(a_i > a_{i - 1}\) 。
不难想到线段树维护 DDP 做到 \(O(n + q \log n)\) ,但是转移中并没有广义矩阵乘法的形式
由于 01 序列是唯一的,因此转移可以钦定取 \(\max\) (或取 \(\min\)),这样就可以写成广义矩阵乘法的形式。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const int N = 1e5 + 7;
struct Matrix {
ll a[3][3];
inline Matrix() {
memset(a, -inf, sizeof(a));
}
inline Matrix operator * (const Matrix &rhs) const {
Matrix res;
for (int i = 0; i < 3; ++i)
for (int j = 0; j < 3; ++j)
for (int k = 0; k < 3; ++k)
res.a[i][k] = max(res.a[i][k], a[i][j] + rhs.a[j][k]);
return res;
}
};
int a[N];
int n, q;
namespace SMT {
Matrix mt[N << 2];
inline int ls(int x) {
return x << 1;
}
inline int rs(int x) {
return x << 1 | 1;
}
void update(int x, int nl, int nr, int p) {
if (nl == nr) {
mt[x] = Matrix(), mt[x].a[1][2] = mt[x].a[0][1 + (a[p] > a[p - 1])] = 0, mt[x].a[1][0] = a[p];
if (a[p] > a[p - 1])
mt[x].a[2][0] = a[p];
return;
}
int mid = (nl + nr) >> 1;
if (p <= mid)
update(ls(x), nl, mid, p);
else
update(rs(x), mid + 1, nr, p);
mt[x] = mt[ls(x)] * mt[rs(x)];
}
} // namespace SMT
signed main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
scanf("%d", a + i);
for (int i = 2; i <= n; ++i)
SMT::update(1, 2, n, i);
scanf("%d", &q);
while (q--) {
int x, k;
scanf("%d%d", &x, &k);
a[x] = k;
if (x > 1)
SMT::update(1, 2, n, x);
if (x < n)
SMT::update(1, 2, n, x + 1);
Matrix ans = SMT::mt[1];
printf("%lld\n", max(a[1] + max(ans.a[0][0], ans.a[0][1]), max(ans.a[2][0], ans.a[2][1])));
}
return 0;
}
树上动态 DP
给定一棵 \(n\) 个点的树,点带点权。
有 \(m\) 次操作,每次操作给定 \(x,y\),表示修改点 \(x\) 的权值为 \(y\)。
每次操作后求这棵树的最大权独立集的权值。
\(n, m \le 10^5\)
先考虑没有修改的情况。设 \(f_{u, 0/1}\) 表示不选/选 \(u\) 时子树内最大权独立集,则:
特殊地,若 \(u\) 为叶子节点,则 \(f_{u, 0} = 0, f_{u, 1} = a_u\) 。答案即为 \(\max(f_{1, 0}, f_{1, 1})\) 。
轻重链剖分
轻重链剖分后,考虑将每条重链视作一个序列。假设现在要处理的重链从上到下为 \(u_{1 \sim k}\) ,且每个 \(u_i\) 的所有轻儿子的 DP 值已经处理好,此时需要按 \(u_k \to u_{k - 1} \to \cdots \to u_1\) 的顺序依次计算出重链上的点的 DP 值。
设 \(u_i\) 的重儿子为 \(son_{u_i} = u_{i + 1} (i < k)\) ,则:
记 \(g_{u, 0/1}\) 表示 \(u\) 子树去掉重子树的 DP 值,则:
将其写作 \((\max, +)\) 广义矩阵乘法的形式:
对于当前重链,每个轻儿子的 \(g\) 可以视为已知量。只要能计算出重链上 \(\begin{bmatrix} g_{u_i, 0} & g_{u_i, 0} \\ g_{u_i, 1} & -\infty \end{bmatrix}\) 的乘积,就可以求出重链上任意点的 DP 值。
考虑用线段树维护重链上矩阵的乘积,单点修改时,矩阵会发生变化的点只有 \(u\) 以及 \(1 \to u\) 路径上轻边的父节点,只要在线段树中修改这 \(O(\log n)\) 个点即可,时间复杂度 \(O(m \log^2 n)\) 。
#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 1e5 + 7;
struct Graph {
vector<int> e[N];
inline void insert(int u, int v) {
e[u].emplace_back(v);
}
} G;
struct Matrix {
int a[2][2];
inline Matrix() {
memset(a, -inf, sizeof(a));
}
inline Matrix operator * (const Matrix &rhs) {
Matrix res;
for (int i = 0; i < 2; ++i)
for (int j = 0; j < 2; ++j)
for (int k = 0; k < 2; ++k)
res.a[i][k] = max(res.a[i][k], a[i][j] + rhs.a[j][k]);
return res;
}
} mt[N];
int f[N][2], a[N], fa[N], dep[N], siz[N], son[N], top[N], bottom[N], dfn[N], id[N];
int n, m, dfstime;
void dfs1(int u, int father) {
fa[u] = father, dep[u] = dep[father] + 1, siz[u] = 1;
f[u][0] = 0, f[u][1] = a[u];
for (int v : G.e[u]) {
if (v == father)
continue;
dfs1(v, u), siz[u] += siz[v];
if (siz[v] > siz[son[u]])
son[u] = v;
f[u][0] += max(f[v][0], f[v][1]), f[u][1] += f[v][0];
}
}
void dfs2(int u, int topf) {
top[u] = topf, bottom[u] = u, id[dfn[u] = ++dfstime] = u;
if (son[u])
dfs2(son[u], topf), bottom[u] = bottom[son[u]];
int g[2] = {0, a[u]};
for (int v : G.e[u])
if (v != fa[u] && v != son[u])
dfs2(v, v), g[0] += max(f[v][0], f[v][1]), g[1] += f[v][0];
mt[u].a[0][0] = mt[u].a[0][1] = g[0];
mt[u].a[1][0] = g[1], mt[u].a[1][1] = -inf;
}
namespace SMT {
Matrix s[N << 2];
inline int ls(int x) {
return x << 1;
}
inline int rs(int x) {
return x << 1 | 1;
}
void build(int x, int l, int r) {
if (l == r) {
s[x] = mt[id[l]];
return;
}
int mid = (l + r) >> 1;
build(ls(x), l, mid), build(rs(x), mid + 1, r);
s[x] = s[ls(x)] * s[rs(x)];
}
void update(int x, int nl, int nr, int p) {
if (nl == nr) {
s[x] = mt[id[p]];
return;
}
int mid = (nl + nr) >> 1;
if (p <= mid)
update(ls(x), nl, mid, p);
else
update(rs(x), mid + 1, nr, p);
s[x] = s[ls(x)] * s[rs(x)];
}
Matrix query(int x, int nl, int nr, int l, int r) {
if (l <= nl && nr <= r)
return s[x];
int mid = (nl + nr) >> 1;
if (r <= mid)
return query(ls(x), nl, mid, l, r);
else if (l > mid)
return query(rs(x), mid + 1, nr, l, r);
else
return query(ls(x), nl, mid, l, r) * query(rs(x), mid + 1, nr, l, r);
}
} // namespace SMT
inline void update(int x, int k) {
mt[x].a[1][0] += k - a[x], a[x] = k;
while (x) {
Matrix bef = SMT::query(1, 1, n, dfn[top[x]], dfn[bottom[x]]);
SMT::update(1, 1, n, dfn[x]);
Matrix aft = SMT::query(1, 1, n, dfn[top[x]], dfn[bottom[x]]);
x = fa[top[x]];
mt[x].a[0][0] += max(aft.a[0][0], aft.a[1][0]) - max(bef.a[0][0], bef.a[1][0]);
mt[x].a[0][1] += max(aft.a[0][0], aft.a[1][0]) - max(bef.a[0][0], bef.a[1][0]);
mt[x].a[1][0] += aft.a[0][0] - bef.a[0][0];
}
}
signed main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i)
scanf("%d", a + i);
for (int i = 1; i < n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G.insert(u, v), G.insert(v, u);
}
dfs1(1, 0), dfs2(1, 1), SMT::build(1, 1, n);
while (m--) {
int x, k;
scanf("%d%d", &x, &k);
update(x, k);
Matrix ans = SMT::query(1, 1, n, 1, dfn[bottom[1]]);
printf("%d\n", max(ans.a[0][0], ans.a[1][0]));
}
return 0;
}
LCT
不难发现 Splay 同样可以维护矩阵的连乘,将 LCT 的实链剖分做类似轻重链剖分的维护即可,注意修改虚实边时需要同时修改矩阵。
时间复杂度 \(O((n + m) \log n)\) ,常数较大。
#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 1e5 + 7;
struct Graph {
vector<int> e[N];
inline void insert(int u, int v) {
e[u].emplace_back(v);
}
} G;
struct Matrix {
int a[2][2];
inline Matrix() {
memset(a, -inf, sizeof(a));
}
inline Matrix operator * (const Matrix &rhs) {
Matrix res;
for (int i = 0; i < 2; ++i)
for (int j = 0; j < 2; ++j)
for (int k = 0; k < 2; ++k)
res.a[i][k] = max(res.a[i][k], a[i][j] + rhs.a[j][k]);
return res;
}
};
int a[N], f[N][2];
int n, m;
namespace LCT {
Matrix mt[N], s[N];
int ch[N][2], fa[N];
inline int isroot(int x) {
return x != ch[fa[x]][0] && x != ch[fa[x]][1];
}
inline int dir(int x) {
return x == ch[fa[x]][1];
}
inline void pushup(int x) {
s[x] = mt[x];
if (ch[x][0])
s[x] = s[ch[x][0]] * s[x];
if (ch[x][1])
s[x] = s[x] * s[ch[x][1]];
}
inline void rotate(int x) {
int y = fa[x], z = fa[y], d = dir(x);
if (!isroot(y))
ch[z][dir(y)] = x;
fa[x] = z, ch[y][d] = ch[x][d ^ 1];
if (ch[x][d ^ 1])
fa[ch[x][d ^ 1]] = y;
ch[x][d ^ 1] = y, fa[y] = x;
pushup(y), pushup(x);
}
inline void splay(int x) {
for (int f = fa[x]; !isroot(x); rotate(x), f = fa[x])
if (!isroot(f))
rotate(dir(f) == dir(x) ? f : x);
}
inline void access(int x) {
for (int y = 0; x; y = x, x = fa[x]) {
splay(x);
if (y) {
mt[x].a[0][0] -= max(s[y].a[0][0], s[y].a[1][0]);
mt[x].a[0][1] -= max(s[y].a[0][0], s[y].a[1][0]);
mt[x].a[1][0] -= s[y].a[0][0];
}
if (ch[x][1]) {
mt[x].a[0][0] += max(s[ch[x][1]].a[0][0], s[ch[x][1]].a[1][0]);
mt[x].a[0][1] += max(s[ch[x][1]].a[0][0], s[ch[x][1]].a[1][0]);
mt[x].a[1][0] += s[ch[x][1]].a[0][0];
}
ch[x][1] = y, pushup(x);
}
}
inline void update(int x, int k) {
access(x), splay(x);
mt[x].a[1][0] += k - a[x], a[x] = k;
pushup(x);
}
} // namespace LCT
void dfs(int u, int fa) {
LCT::fa[u] = fa, f[u][0] = 0, f[u][1] = a[u];
for (int v : G.e[u])
if (v != fa)
dfs(v, u), f[u][0] += max(f[v][0], f[v][1]), f[u][1] += f[v][0];
LCT::mt[u].a[0][0] = LCT::mt[u].a[0][1] = f[u][0];
LCT::mt[u].a[1][0] = f[u][1], LCT::mt[u].a[1][1] = -inf;
LCT::s[u] = LCT::mt[u];
}
signed main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i)
scanf("%d", a + i);
for (int i = 1; i < n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G.insert(u, v), G.insert(v, u);
}
dfs(1, 0);
while (m--) {
int x, k;
scanf("%d%d", &x, &k);
LCT::update(x, k), LCT::splay(1);
printf("%d\n", max(LCT::s[1].a[0][0], LCT::s[1].a[1][0]));
}
return 0;
}
全局平衡二叉树
考虑综合轻重链剖分和 LCT 的结构,对每条重链用一棵二叉树维护,然后将每条重链的二叉树用虚边连起来。
先建立重链的二叉树。定义 \(\mathrm{Lsize}(u)\) 表示 \(u\) 所有轻子树的大小和 \(+1\) ,则考虑将一条重链的二叉树按 \(\mathrm{Lsize}(u)\) 为权值建立,即每次按 \(\mathrm{Lsize}(u)\) 为权值找到带权中点,然后递归左右建树。
再连不同重链之间的虚边,只要将该重链的二叉树的根连向该重链的父亲即可。
最后考虑维护重链信息,只要忽略虚边,维护子树信息即可。
这个结构称之为全局平衡二叉树,修改时在上面暴力跳并更新信息,可以证明均摊总时间复杂度为 \(O((n + m) \log n)\) 。
证明:考虑拉出一条重链,轻子树挂在下面,则 \(\mathrm{Lsize}(u)\) 可以视为 \(u\) 的子树大小。每次深度 \(+1\) 时,子树大小都会减半,因此树高是 \(O(\log n)\) 的。
#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 1e6 + 7;
struct Graph {
vector<int> e[N];
inline void insert(int u, int v) {
e[u].emplace_back(v);
}
} G;
struct Matrix {
int a[2][2];
inline Matrix() {
memset(a, -inf, sizeof(a));
}
inline Matrix operator * (const Matrix &rhs) {
Matrix res;
for (int i = 0; i < 2; ++i)
for (int j = 0; j < 2; ++j)
for (int k = 0; k < 2; ++k)
res.a[i][k] = max(res.a[i][k], a[i][j] + rhs.a[j][k]);
return res;
}
};
int f[N][2], a[N], siz[N], son[N], lsiz[N];
int n, m;
namespace BST {
Matrix mt[N], s[N];
int fa[N], lc[N], rc[N];
int root;
inline bool isroot(int x) {
return x != lc[fa[x]] && x != rc[fa[x]];
}
inline void pushup(int x) {
s[x] = mt[x];
if (lc[x])
s[x] = s[lc[x]] * s[x];
if (rc[x])
s[x] = s[x] * s[rc[x]];
}
int build(int l, int r, vector<int> &vec, int f) {
if (l > r)
return 0;
int all = 0;
for (int i = l; i <= r; ++i)
all += lsiz[vec[i]];
int mid = -1;
for (int i = l, sum = 0; i <= r; ++i)
if ((sum += lsiz[vec[i]]) >= all / 2) {
mid = i;
break;
}
int x = vec[mid];
fa[x] = f, lc[x] = build(l, mid - 1, vec, x), rc[x] = build(mid + 1, r, vec, x);
return pushup(x), x;
}
int build(int u, int f) {
vector<int> vec;
for (int x = u; x; x = son[x]) {
for (int v : G.e[x])
if (v != (x == u ? f : vec.back()) && v != son[x])
build(v, x);
vec.emplace_back(x);
}
return build(0, vec.size() - 1, vec, f);
}
inline void update(int x, int k) {
mt[x].a[1][0] += k - a[x], a[x] = k;
while (x) {
if (isroot(x) && fa[x]) {
Matrix bef = s[x];
pushup(x);
Matrix aft = s[x];
x = fa[x];
mt[x].a[0][0] += max(aft.a[0][0], aft.a[1][0]) - max(bef.a[0][0], bef.a[1][0]);
mt[x].a[0][1] += max(aft.a[0][0], aft.a[1][0]) - max(bef.a[0][0], bef.a[1][0]);
mt[x].a[1][0] += aft.a[0][0] - bef.a[0][0];
} else
pushup(x), x = fa[x];
}
}
} // namespace BST
void dfs(int u, int fa) {
siz[u] = 1, f[u][0] = 0, f[u][1] = a[u];
for (int v : G.e[u]) {
if (v == fa)
continue;
dfs(v, u), siz[u] += siz[v];
if (siz[v] > siz[son[u]])
son[u] = v;
f[u][0] += max(f[v][0], f[v][1]), f[u][1] += f[v][0];
}
lsiz[u] = siz[u] - siz[son[u]];
BST::mt[u].a[0][0] = BST::mt[u].a[0][1] = f[u][0] - max(f[son[u]][0], f[son[u]][1]);
BST::mt[u].a[1][0] = f[u][1] - f[son[u]][0], BST::mt[u].a[1][1] = -inf;
}
signed main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i)
scanf("%d", a + i);
for (int i = 1; i < n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G.insert(u, v), G.insert(v, u);
}
dfs(1, 0), BST::root = BST::build(1, 0);
int lstans = 0;
while (m--) {
int x, k;
scanf("%d%d", &x, &k), x ^= lstans;
BST::update(x, k);
printf("%d\n", lstans = max(BST::s[BST::root].a[0][0], BST::s[BST::root].a[1][0]));
}
return 0;
}
应用
P5024 [NOIP 2018 提高组] 保卫王国
给出一棵树,定义一个驻扎方案为:
- 每个点要么不驻扎军队,要么驻扎军队。
- 每条边至少一段需要驻扎军队。
定义该方案的花费为驻扎军队的点的 \(p\) 的和。
\(m\) 次询问,每次给出 \(a, b, x, y\) 表示 \(a\) 是否强制驻扎军队与 \(b\) 是否强制驻扎军队,求最小花费方案,或告知无解。
\(n, m \le 10^5\)
问题可以补集转化为求树上最大权独立集,给出的限制可以通过给 \(a\) 或 \(b\) 的 \(p\) 赋上 \(\pm \infty\) 的值实现。
采用全局平衡二叉树实现,时间复杂度 \(O((n + m) \log n)\) 。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 1e18;
const int N = 1e6 + 7;
struct Graph {
vector<int> e[N];
inline void insert(int u, int v) {
e[u].emplace_back(v);
}
} G;
struct Matrix {
ll a[2][2];
inline Matrix() {
a[0][0] = a[0][1] = a[1][0] = a[1][1] = -inf;
}
inline Matrix operator * (const Matrix &rhs) {
Matrix res;
for (int i = 0; i < 2; ++i)
for (int j = 0; j < 2; ++j)
for (int k = 0; k < 2; ++k)
res.a[i][k] = max(res.a[i][k], a[i][j] + rhs.a[j][k]);
return res;
}
};
ll f[N][2], val[N];
int fa[N], siz[N], son[N], lsiz[N];
int n, m;
namespace BST {
Matrix mt[N], s[N];
int fa[N], lc[N], rc[N];
int root;
inline bool isroot(int x) {
return x != lc[fa[x]] && x != rc[fa[x]];
}
inline void pushup(int x) {
s[x] = mt[x];
if (lc[x])
s[x] = s[lc[x]] * s[x];
if (rc[x])
s[x] = s[x] * s[rc[x]];
}
int build(int l, int r, vector<int> &vec, int f) {
if (l > r)
return 0;
int all = 0;
for (int i = l; i <= r; ++i)
all += lsiz[vec[i]];
int mid = -1;
for (int i = l, sum = 0; i <= r; ++i)
if ((sum += lsiz[vec[i]]) >= all / 2) {
mid = i;
break;
}
int x = vec[mid];
fa[x] = f, lc[x] = build(l, mid - 1, vec, x), rc[x] = build(mid + 1, r, vec, x);
return pushup(x), x;
}
int build(int u, int f) {
vector<int> vec;
for (int x = u; x; x = son[x]) {
for (int v : G.e[x])
if (v != (x == u ? f : vec.back()) && v != son[x])
build(v, x);
vec.emplace_back(x);
}
return build(0, vec.size() - 1, vec, f);
}
inline void update(int x, ll k) {
mt[x].a[1][0] += k - val[x], val[x] = k;
while (x) {
if (isroot(x) && fa[x]) {
Matrix bef = s[x];
pushup(x);
Matrix aft = s[x];
x = fa[x];
mt[x].a[0][0] += max(aft.a[0][0], aft.a[1][0]) - max(bef.a[0][0], bef.a[1][0]);
mt[x].a[0][1] += max(aft.a[0][0], aft.a[1][0]) - max(bef.a[0][0], bef.a[1][0]);
mt[x].a[1][0] += aft.a[0][0] - bef.a[0][0];
} else
pushup(x), x = fa[x];
}
}
} // namespace BST
void dfs(int u, int fa) {
siz[u] = 1, f[u][0] = 0, f[u][1] = val[u];
for (int v : G.e[u]) {
if (v == fa)
continue;
dfs(v, u), siz[u] += siz[v];
if (siz[v] > siz[son[u]])
son[u] = v;
f[u][0] += max(f[v][0], f[v][1]), f[u][1] += f[v][0];
}
lsiz[u] = siz[u] - siz[son[u]];
BST::mt[u].a[0][0] = BST::mt[u].a[0][1] = f[u][0] - max(f[son[u]][0], f[son[u]][1]);
BST::mt[u].a[1][0] = f[u][1] - f[son[u]][0], BST::mt[u].a[1][1] = -inf;
}
signed main() {
scanf("%d%d%*s", &n, &m);
ll sum = 0;
for (int i = 1; i <= n; ++i)
scanf("%lld", val + i), sum += val[i];
for (int i = 1; i < n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G.insert(u, v), G.insert(v, u);
}
dfs(1, 0), BST::root = BST::build(1, 0);
while (m--) {
int a, x, b, y;
scanf("%d%d%d%d", &a, &x, &b, &y);
ll pa = val[a], pb = val[b], ans = sum + (inf - val[a]) * (x ^ 1) + (inf - val[b]) * (y ^ 1);
BST::update(a, x ? -inf : inf), BST::update(b, y ? -inf : inf);
ans -= max(BST::s[BST::root].a[0][0], BST::s[BST::root].a[1][0]);
printf("%lld\n", ans <= sum ? ans : -1);
BST::update(a, pa), BST::update(b, pb);
}
return 0;
}
P6021 洪水
给出一棵树,点带权,\(m\) 次操作,操作有:
- 单点加点权。
- 标记 \(u\) 子树中的若干点使得 \(u\) 与子树内的任意叶子之间的路径存在标记,最小化标记点权值和。
\(n, m \le 2 \times 10^5\)
先不考虑子树的限制,设 \(f_u\) 表示 \(u\) 子树的答案,则:
由于需要动态修改点权,因此考虑动态 DP。
设 \(g_u\) 表示 \(u\) 子树内轻儿子的 \(f\) 和,对于一条重链 \(u_{1 \sim k}\) ,考虑 \((\min, +)\) 广义矩阵乘法,则:
不难用轻重链剖分维护动态 DP 做到 \(O(m \log^2 n)\) 。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const int N = 2e5 + 7;
struct Graph {
vector<int> e[N];
inline void insert(int u, int v) {
e[u].emplace_back(v);
}
} G;
struct Matrix {
ll a[2][2];
inline Matrix() {
memset(a, inf, sizeof(a));
}
inline Matrix operator * (const Matrix &rhs) {
Matrix res;
for (int i = 0; i < 2; ++i)
for (int j = 0; j < 2; ++j)
for (int k = 0; k < 2; ++k)
res.a[i][k] = min(res.a[i][k], a[i][j] + rhs.a[j][k]);
return res;
}
} mt[N];
ll val[N], f[N];
int fa[N], dep[N], siz[N], son[N], top[N], bottom[N], dfn[N], id[N];
int n, m, dfstime;
void dfs1(int u, int father) {
fa[u] = father, dep[u] = dep[father] + 1, siz[u] = 1;
ll sum = 0;
for (int v : G.e[u]) {
if (v == father)
continue;
dfs1(v, u), siz[u] += siz[v], sum += f[v];
if (siz[v] > siz[son[u]])
son[u] = v;
}
if (G.e[u].size() == 1)
sum = inf;
f[u] = min(val[u], sum);
mt[u].a[0][0] = sum - f[son[u]], mt[u].a[0][1] = val[u];
mt[u].a[1][0] = inf, mt[u].a[1][1] = 0;
}
void dfs2(int u, int topf) {
top[u] = topf, bottom[u] = u, id[dfn[u] = ++dfstime] = u;
if (son[u])
dfs2(son[u], topf), bottom[u] = bottom[son[u]];
for (int v : G.e[u])
if (v != fa[u] && v != son[u])
dfs2(v, v);
}
namespace SMT {
Matrix s[N << 2];
inline int ls(int x) {
return x << 1;
}
inline int rs(int x) {
return x << 1 | 1;
}
void build(int x, int l, int r) {
if (l == r) {
s[x] = mt[id[l]];
return;
}
int mid = (l + r) >> 1;
build(ls(x), l, mid), build(rs(x), mid + 1, r);
s[x] = s[ls(x)] * s[rs(x)];
}
void update(int x, int nl, int nr, int p) {
if (nl == nr) {
s[x] = mt[id[p]];
return;
}
int mid = (nl + nr) >> 1;
if (p <= mid)
update(ls(x), nl, mid, p);
else
update(rs(x), mid + 1, nr, p);
s[x] = s[ls(x)] * s[rs(x)];
}
Matrix query(int x, int nl, int nr, int l, int r) {
if (l <= nl && nr <= r)
return s[x];
int mid = (nl + nr) >> 1;
if (r <= mid)
return query(ls(x), nl, mid, l, r);
else if (l > mid)
return query(rs(x), mid + 1, nr, l, r);
else
return query(ls(x), nl, mid, l, r) * query(rs(x), mid + 1, nr, l, r);
}
} // namespace SMT
inline void update(int x, int k) {
mt[x].a[0][1] += k;
while (x) {
Matrix bef = SMT::query(1, 1, n, dfn[top[x]], dfn[bottom[x]]);
SMT::update(1, 1, n, dfn[x]);
Matrix aft = SMT::query(1, 1, n, dfn[top[x]], dfn[bottom[x]]);
mt[x = fa[top[x]]].a[0][0] += min(aft.a[0][0], aft.a[0][1]) - min(bef.a[0][0], bef.a[0][1]);
}
}
signed main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
scanf("%lld", val + i);
for (int i = 1; i < n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G.insert(u, v), G.insert(v, u);
}
dfs1(1, 0), dfs2(1, 1), SMT::build(1, 1, n);
scanf("%d", &m);
while (m--) {
char op[3];
scanf("%s", op);
if (op[0] == 'Q') {
int x;
scanf("%d", &x);
Matrix ans = SMT::query(1, 1, n, dfn[x], dfn[bottom[x]]);
printf("%lld\n", min(ans.a[0][0], ans.a[0][1]));
} else {
int x, k;
scanf("%d%d", &x, &k);
update(x, k);
}
}
return 0;
}
P3781 [SDOI2017] 切树游戏
给出一棵树,点带权,\(q\) 次操作,操作有:
- 单点修改点权。
- 查询树上异或和为 \(k\) 的连通块数量 \(\bmod 10007\) 。
\(n, q \le 30000\) ,点权 \(\le 127\)
设 \(f_{u, i}\) 表示 \(u\) 为最浅点的连通块异或和为 \(i\) 的方案数,转移直接做异或卷积。可以先将每个位置的点权做一遍 FWT,则转移就是单点对单点地乘,最后求解完再 IFWT 回来。
再设 \(g_{u, i} = \sum_{v \in subtree(u)} f_{u, i}\) ,则最后求出 \(g_{1, k}\) 即可。
动态修改点权则考虑矩阵乘法,记 \(f'\) 表示自己与轻儿子的 DP 值,\(g'\) 表示所有轻儿子的 DP 值,则:
由于动态 DP 会出现需要撤销某个点贡献的情况,此时若 DP 值为 \(0\) ,则无法处理。考虑将每个数记为 \(x \times 0^y\) 的形式,则不难做除法操作。
直接做矩乘会带上 \(3^3 = 27\) 的常数,可以发现这个矩阵很稀疏,并且:
因此只要维护四个位置即可,大大优化了常数。
全局平衡二叉树维护动态 DP,时间复杂度 \(O(V (n + q) \log n)\) 。
#include <bits/stdc++.h>
using namespace std;
const int Mod = 1e4 + 7, inv2 = (Mod + 1) / 2;
const int N = 3e4 + 7, V = 1 << 7;
struct Graph {
vector<int> e[N];
inline void insert(int u, int v) {
e[u].emplace_back(v);
}
} G;
int fwt[V][V], f[N][V], g[N][V], g2[N][V];
int inv[Mod], val[N], siz[N], son[N], lsiz[N];
int n, m, q;
inline int add(int x, int y) {
x += y;
if (x >= Mod)
x -= Mod;
return x;
}
inline int dec(int x, int y) {
x -= y;
if (x < 0)
x += Mod;
return x;
}
inline void FWT(int *f, int n, int op) {
for (int k = 1; k < n; k <<= 1)
for (int i = 0; i < n; i += k << 1)
for (int j = 0; j < k; ++j) {
int fl = f[i + j], fr = f[i + j + k];
f[i + j] = 1ll * add(fl, fr) * op % Mod;
f[i + j + k] = 1ll * dec(fl, fr) * op % Mod;
}
}
inline void prework() {
inv[0] = inv[1] = 1;
for (int i = 2; i < Mod; ++i)
inv[i] = 1ll * (Mod - Mod / i) * inv[Mod % i] % Mod;
for (int i = 0; i < m; ++i)
fwt[i][i] = 1, FWT(fwt[i], m, 1);
}
struct Node {
int x, y;
inline Node(int k = 0) {
if (k)
x = k, y = 0;
else
x = y = 1;
}
inline friend Node operator * (Node a, int b) {
if (b)
a.x = 1ll * a.x * b % Mod;
else
++a.y;
return a;
}
inline friend Node operator / (Node a, int b) {
if (b)
a.x = 1ll * a.x * inv[b] % Mod;
else
--a.y;
return a;
}
inline operator int() {
return y ? 0 : x;
}
} f2[N][V];
struct Matrix {
int a, b, c, d;
inline Matrix() {
a = b = c = d = 0;
}
inline Matrix(int f, int g) {
a = b = c = f, d = add(f, g);
}
inline Matrix(int _a, int _b, int _c, int _d) : a(_a), b(_b), c(_c), d(_d) {}
inline friend Matrix operator * (Matrix a, Matrix b) {
return (Matrix){1ll * a.a * b.a % Mod, add(1ll * a.a * b.b % Mod, a.b),
add(1ll * b.a * a.c % Mod, b.c), add(1ll * a.c * b.b % Mod, add(a.d, b.d))};
}
} mt[N];
namespace BST {
Matrix s[N][V];
int fa[N], lc[N], rc[N];
int root;
inline bool isroot(int x) {
return x != lc[fa[x]] && x != rc[fa[x]];
}
inline void pushup(int x) {
for (int i = 0; i < m; ++i) {
s[x][i] = Matrix(f2[x][i], g2[x][i]);
if (lc[x])
s[x][i] = s[lc[x]][i] * s[x][i];
if (rc[x])
s[x][i] = s[x][i] * s[rc[x]][i];
}
}
int build(int l, int r, vector<int> &vec, int f) {
if (l > r)
return 0;
int all = 0;
for (int i = l; i <= r; ++i)
all += lsiz[vec[i]];
int mid = -1;
for (int i = l, sum = 0; i <= r; ++i)
if ((sum += lsiz[vec[i]]) >= all / 2) {
mid = i;
break;
}
int x = vec[mid];
fa[x] = f, lc[x] = build(l, mid - 1, vec, x), rc[x] = build(mid + 1, r, vec, x);
return pushup(x), x;
}
int build(int u, int f) {
vector<int> vec;
for (int x = u; x; x = son[x]) {
for (int v : G.e[x])
if (v != (x == u ? f : vec.back()) && v != son[x])
build(v, x);
vec.emplace_back(x);
}
return build(0, vec.size() - 1, vec, f);
}
inline void update(int x, int k) {
for (int i = 0; i < m; ++i)
f2[x][i] = f2[x][i] / fwt[val[x]][i] * fwt[k][i];
val[x] = k;
while (x) {
if (isroot(x) && fa[x]) {
for (int i = 0; i < m; ++i) {
f2[fa[x]][i] = f2[fa[x]][i] / add(s[x][i].b, 1);
g2[fa[x]][i] = dec(g2[fa[x]][i], s[x][i].d);
}
pushup(x);
for (int i = 0; i < m; ++i) {
f2[fa[x]][i] = f2[fa[x]][i] * add(s[x][i].b, 1);
g2[fa[x]][i] = add(g2[fa[x]][i], s[x][i].d);
}
x = fa[x];
} else
pushup(x), x = fa[x];
}
}
inline int query(int k) {
static int f[V];
for (int i = 0; i < m; ++i)
f[i] = s[root][i].d;
return FWT(f, m, inv2), f[k];
}
} // namespace BST
void dfs(int u, int fa) {
siz[u] = 1;
memcpy(f[u], fwt[val[u]], sizeof(int) * m);
for (int v : G.e[u]) {
if (v == fa)
continue;
dfs(v, u), siz[u] += siz[v];
if (siz[v] > siz[son[u]])
son[u] = v;
for (int i = 0; i < m; ++i)
f[u][i] = add(f[u][i], 1ll * f[u][i] * f[v][i] % Mod), g[u][i] = add(g[u][i], g[v][i]);
}
lsiz[u] = siz[u] - siz[son[u]];
for (int i = 0; i < m; ++i)
g[u][i] = add(g[u][i], f[u][i]), f2[u][i] = Node(fwt[val[u]][i]);
for (int v : G.e[u])
if (v != fa && v != son[u]) {
for (int i = 0; i < m; ++i)
f2[u][i] = f2[u][i] * add(f[v][i], 1), g2[u][i] = add(g2[u][i], g[v][i]);
}
}
signed main() {
scanf("%d%d", &n, &m);
prework();
for (int i = 1; i <= n; ++i)
scanf("%d", val + i);
for (int i = 1; i < n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G.insert(u, v), G.insert(v, u);
}
dfs(1, 0), BST::root = BST::build(1, 0);
scanf("%d", &q);
while (q--) {
char op[7];
scanf("%s", op);
if (op[0] == 'Q') {
int x;
scanf("%d", &x);
printf("%d\n", BST::query(x));
} else {
int x, k;
scanf("%d%d", &x, &k);
BST::update(x, k);
}
}
return 0;
}
P8820 [CSP-S 2022] 数据传输
给出一棵树,规定两个点可以互相传送信息当且仅当两个点的距离不超过 \(k\) 。
\(q\) 次询问,每次给出 \(s, t\) ,求一段传送途径 \(\{x_1 = s, x_2, \cdots, x_{k - 1}, x_k = t \}\) 满足相邻两点可以互相传送信息,最小化 \(\sum_{i = 1}^k a_{x_i}\) 。
\(n, q \le 2 \times 10^5\) ,\(k \le 3\)
考虑询问时把这条链拉出来 DP,记这条链为 \(Path(s \to t) = \{ u_1 = s, u_2, \cdots, u_{k - 1}, u_k = t \}\) 。
当 \(k = 1\) 时,即查询链上权值和。
当 \(k = 2\) 时,设 \(f_i\) 表示 \(s \to u_i\) 的答案,则 \(f_i = a_i + \min(f_{i - 1}, f_{i - 2})\) 。
当 \(k = 3\) 时,设 \(f_{i, 0/1/2}\) 表示 \(s\) 到距离 \(u_i\) 距离为 \(0/1/2\) 的、所属链上的点的编号 \(\le i\) 的点的最小权值和,\(b_u\) 表示 \(u\) 邻域的最小权值,则:
以上操作不难用 \(k \times k\) 的矩阵维护,树上倍增维护矩阵乘积,时间复杂度 \(O((n + q) k^3 \log n)\) 。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const int N = 2e5 + 7, M = 3, LOGN = 19;
struct Graph {
vector<int> e[N];
inline void insert(int u, int v) {
e[u].emplace_back(v);
}
} G;
ll a[N];
int fa[N][LOGN], dep[N];
int n, q, m, dfstime;
struct Matrix {
ll a[M][M];
inline Matrix(ll k = inf) {
memset(a, inf, sizeof(a));
for (int i = 0; i < m; ++i)
a[i][i] = k;
}
inline Matrix operator * (const Matrix &rhs) {
Matrix res;
for (int i = 0; i < m; ++i)
for (int j = 0; j < m; ++j)
for (int k = 0; k < m; ++k)
res.a[i][k] = min(res.a[i][k], a[i][j] + rhs.a[j][k]);
return res;
}
} mt[N], up[N][LOGN], down[N][LOGN];
void dfs(int u, int f) {
fa[u][0] = f, dep[u] = dep[f] + 1, up[u][0] = down[u][0] = mt[u];
for (int i = 1; i < LOGN; ++i) {
fa[u][i] = fa[fa[u][i - 1]][i - 1];
up[u][i] = up[u][i - 1] * up[fa[u][i - 1]][i - 1];
down[u][i] = down[fa[u][i - 1]][i - 1] * down[u][i - 1];
}
for (int v : G.e[u])
if (v != f)
dfs(v, u);
}
inline int LCA(int x, int y) {
if (dep[x] < dep[y])
swap(x, y);
for (int i = 0, h = dep[x] - dep[y]; h; ++i, h >>= 1)
if (h & 1)
x = fa[x][i];
if (x == y)
return x;
for (int i = LOGN - 1; ~i; --i)
if (fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
inline Matrix query(int x, int y) {
int lca = LCA(x, y);
Matrix resx(0), resy(0);
if (x != lca) {
x = fa[x][0];
while (x != lca) {
int d = __lg(dep[x] - dep[lca]);
resx = resx * up[x][d], x = fa[x][d];
}
resx = resx * mt[lca];
}
while (y != lca) {
int d = __lg(dep[y] - dep[lca]);
resy = down[y][d] * resy, y = fa[y][d];
}
return resx * resy;
}
signed main() {
scanf("%d%d%d", &n, &q, &m);
for (int i = 1; i <= n; ++i)
scanf("%lld", a + i);
for (int i = 1; i < n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G.insert(u, v), G.insert(v, u);
}
for (int i = 1; i <= n; ++i) {
for (int j = 0; j + 1 < m; ++j)
mt[i].a[j][j + 1] = 0;
for (int j = 0; j < m; ++j)
mt[i].a[j][0] = a[i];
if (m == 3) {
for (int x : G.e[i])
mt[i].a[1][1] = min(mt[i].a[1][1], a[x]);
}
}
dfs(1, 0);
while (q--) {
int s, t;
scanf("%d%d", &s, &t);
Matrix f;
f.a[0][0] = a[s], f = f * query(s, t);
printf("%lld\n", f.a[0][0]);
}
return 0;
}