动态dp (Ddp)

就用板子题为例吧。

首先考虑没有更改怎么做,那就是没有上司的舞会,我们设 \(f[u][0/1]\) 表示现在在点 \(u\) ,0是不选自己,1是选自己,转移就是

\[f_{u,0} = \sum_{v \in son_u} \max(f_{v,1},f_{v,0}) \]

\[f_{u,1} = \sum_{v \in son_u} f_{v,0} \]

然后考虑加上更改怎么做,显然加上更改之后只会对更改的这个点以及他的祖先们的dp值有影响,所以我们可以很自然地想到重链剖分,对于整个链一起更改。

然后考虑怎么对于一整个链快速更改,这里我们设一个数组 \(g[u][0/1]\) ,表示这个点的轻儿子们全都不限制/不选的答案,那么现在就变成了

\[f_{u,0} = g_{u,0} + \max(f_{j,0}, f_{j,1}) \]

\[f_{u,1} = g_{u,1} + a_u + f_{j,0} \]

这里的 \(j\) 表示 \(i\) 的重儿子,然后我们发现 \(g_{u,1}\)\(a_u\) 可以合并到一起,然后这个式子就变成了一个非常美妙的东西:

\[f_{u,0} = \max(g_{u,0} + f_{j,0}, g_{u,0} + f_{j,1}) \]

\[f_{u,1} = \max(g_{u,1} + f_{j,0},-\infty) \]

这样这个式子就比较可以快速转移了,参考矩阵乘法,我们可以把乘号重新定义,定义为取 \(max\) ,然后你发现这玩意其实很像一个矩阵,具体就是

\[\begin{bmatrix} g_{u,0} & g_{u,1} \\ g_{u,0} & -\infty \end{bmatrix} \]

然后你就在不断地更新的时候顺便更新矩阵就行了,但是要注意一个点,因为你是用矩阵从链底转移的,所以你每次查询的时候也要从这个点所在的链底往上转移到这个点才有最终答案。

贴一下代码

#include <cstring>
#include <iostream>

using namespace std;
const int N = 1e6 + 10;

int n, m, tot, cnt, Ans;
int head[N], fa[N], sz[N], mxs[N], dep[N];
int a[N], id[N], dfn[N], top[N], End[N];
int f[N][2];
struct Map { int to, nxt; } e[N << 1];

void add (int u, int v) {
    e[++tot] = {v, head[u]};
    head[u] = tot;
}

struct Matrix {
    int mat[2][2];

    Matrix () {
        for (int i = 0; i < 2; ++i)
            for (int j = 0; j < 2; ++j)
                mat[i][j] = -1e9;
    }

    Matrix operator * (Matrix b) {
        Matrix c;
        for (int i = 0; i < 2; ++i)
            for (int j = 0; j < 2; ++j)
                for (int k = 0; k < 2; ++k)
                    c.mat[i][j] = max (c.mat[i][j], mat[i][k] + b.mat[k][j]);
        return c;
    }
} val[N], ans;

struct Seg {
    #define ls (x << 1)
    #define rs (x << 1 | 1)

    int l[N << 2], r[N << 2];
    Matrix M[N << 2];

    void pushup (int x) {
        M[x] = M[ls] * M[rs];
    }

    void build (int x, int L, int R) {
        l[x] = L, r[x] = R;
        if (L == R) {
            return M[x] = val[dfn[L]], void();
        } int mid = (L + R) >> 1;
        build (ls, L, mid), build (rs, mid + 1, R);
        pushup (x);
    }

    void update (int k, int x) {
        if (l[x] == r[x]) {
            return M[x] = val[dfn[k]], void();
        } int mid = (l[x] + r[x]) >> 1;
        if (k <= mid) update (k, ls);
        else update (k, rs);
        pushup (x);
    }

    Matrix query (int x, int L, int R) {
        if (L == l[x] && R == r[x])
            return M[x];
        int mid = (l[x] + r[x]) >> 1;
        if (R <= mid) return query (ls, L, R);
        else if (L > mid) return query (rs, L, R);
        else return query (ls, L, mid) * query (rs, mid + 1, R);
    }
} t;

void dfs1 (int u) {
    sz[u] = 1;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if (v == fa[u]) continue;
        fa[v] = u, dep[v] = dep[u] + 1;
        dfs1 (v), sz[u] += sz[v];
        if (sz[v] > sz[mxs[u]]) mxs[u] = v;
    }
}

void dfs2 (int u, int topf) {
    id[u] = ++cnt, dfn[cnt] = u;
    top[u] = topf, End[topf] = max (End[topf], cnt);
    f[u][0] = 0, f[u][1] = a[u];
    val[u].mat[0][0] = val[u].mat[0][1] = 0;
    val[u].mat[1][0] = a[u];
    if (mxs[u]) {
        dfs2 (mxs[u], topf);
        f[u][0] += max (f[mxs[u]][0], f[mxs[u]][1]);
        f[u][1] += f[mxs[u]][0];
    }
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to;
        if (v == fa[u] || v == mxs[u]) continue;
        dfs2 (v, v);
        f[u][0] += max(f[v][0], f[v][1]);
        f[u][1] += f[v][0];
        val[u].mat[0][0] += max(f[v][0], f[v][1]);
        val[u].mat[0][1] = val[u].mat[0][0], val[u].mat[1][0] += f[v][0];
    }
}

void update (int u, int w) {
    val[u].mat[1][0] += w - a[u], a[u] = w;
    Matrix bef, aft;
    while (u != 0) {
        bef = t.query (1, id[top[u]], End[top[u]]);
        t.update (id[u], 1);
        aft = t.query (1, id[top[u]], End[top[u]]);
        u = fa[top[u]];
        val[u].mat[0][0] += max (aft.mat[0][0], aft.mat[1][0]) - max (bef.mat[0][0], bef.mat[1][0]);
        val[u].mat[0][1] = val[u].mat[0][0], val[u].mat[1][0] += aft.mat[0][0] - bef.mat[0][0];
    }
}

inline int read() {
	int x = 0; bool f = false; char c = getchar();
	while (!isdigit(c)) {if (c == '-') f = true; c = getchar();}
	while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
	return (f ? -x : x);
}

int main () {
    n = read(), m = read();
    for (int i = 1; i <= n; ++i)
        a[i] = read();
    for (int i = 1, u, v; i < n; ++i)
        u = read(), v = read(), add (u, v), add (v, u);
    dfs1 (1), dfs2 (1, 1), t.build (1, 1, n);
    int x, v;
    while (m--) {
        x = read(), v = read(), x = x ^ Ans;
        update (x, v);
        ans = t.query (1, id[1], End[1]);
        Ans = max (ans.mat[0][0], ans.mat[1][0]);
        printf ("%d\n", max (ans.mat[0][0], ans.mat[1][0]));
    } return 0;
}
posted @ 2025-05-10 11:10  Rose_Lu  阅读(43)  评论(2)    收藏  举报