模板索引:树论

树的直径

例题 SP1437 PT07Z - Longest path in a tree

输出直径长度

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e4+10;
int n, dis[maxn];
int A, B;
vector<int> mp[maxn];

void dfs(int x, int fa){
    dis[x] = dis[fa] + 1;
    for(auto to : mp[x]){
        if(to == fa) continue;
        dfs(to, x); 
    }
}


int main(){
    cin >> n;
    for(int i = 1; i < n; i++){
        int u = read(), v = read();
        mp[u].push_back(v);
        mp[v].push_back(u);}
    // 找到直径的一端
    dfs(1, 0);
    for(int i = 1; i <= n; i++) 
        if(dis[i] > dis[A]) A = i;
    dis[A] = 0; dfs(A, 0);
    //从直径的一端开始找另一端
    int ans = 0;
    for(int i = 1; i <= n; i++) 
        if(dis[i] > dis[B]) B = i, ans = dis[B] - dis[A];
    cout << ans << endl;
	return 0;
}

树的重心

重链剖分 & 树上差分 & LCA

笔记

  • 注意rdfn和dfs序的转化使用
  • 注意赋值懒标记有的时候才下传,没有的时候不下传。模板代码的pushdown函数没有判断if(!lzy[u]) return;却不出错的原因是这个标记只处理区间加法

P3128 [USACO15DEC] Max Flow P

FJ给他的牛棚的 𝑁个隔间之间安装了 𝑁−1根管道,隔间编号从 1到 𝑁。所有隔间都被管道连通了。
FJ有 𝐾条运输牛奶的路线,第 𝑖 条路线从隔间 𝑠𝑖 运输到隔间 𝑡𝑖。一条运输路线会给它的两个端点处的隔间以及中间途径的所有隔间带来一个单位的运输压力,你需要计算压力最大的隔间的压力是多少

注意此题为点差分

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 5e4+9; // 点差分模板

int N, T;
vector <int> vec[maxn];
int val[maxn];
int fa[maxn], siz[maxn], son[maxn];
int dep[maxn], top[maxn];


void dfs1(int x, int f){
    fa[x] = f; siz[x] = 1; dep[x] = dep[f] + 1;
    for(auto to : vec[x]){
        if(to == f) continue;
        dfs1(to, x);
        siz[x] += siz[to];
        if(siz[to] > siz[son[x]]) son[x] = to;
    }
}

void dfs2(int x, int tp){
    top[x] = tp;
    if(!son[x]) return ;
    dfs2(son[x], tp);
    for(auto to : vec[x]){
        if(to == fa[x] || to == son[x]) continue;
        dfs2(to, to);
    }
}

int ans;

void dfs3(int x){
    for(auto to : vec[x]){
        if(to == fa[x]) continue;
        dfs3(to);
        val[x] += val[to];
    }
    ans = max(ans, val[x]);
}

int LCA(int x, int y){
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
    if(dep[x] < dep[y]) return x;
    return y;
}

int main(){
    cin >> N >> T;
    for(int i = 1, x, y; i < N; i++){
        x = read(); y = read();
        vec[x].push_back(y);
        vec[y].push_back(x);
    }

    dfs1(1, 0);
    dfs2(1, 1);
    while(T--){
        int x = read(), y = read();
        int lca = LCA(x, y);
        val[x]++; val[y]++;
        val[lca]--; val[fa[lca]]--;

    }
    dfs3(1);
    cout << ans << endl;
	return 0;
}

P4114 Qtree1

边差分与区间最值模板

给定一棵 \(n\) 个节点的树,有两种操作:

  • CHANGE i t 把第 \(i\) 条边的边权变成 \(t\)
  • QUERY a b 输出从 \(a\)\(b\) 的路径上最大的边权。当 \(a=b\) 时,输出 \(0\)
  • **一定要记住绑定到子节点后,u 到 v 的路径上没有 lca(u, v) **
点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 2e5 + 10;
int n;
int head[maxn], to[maxn], nxt[maxn], val[maxn], tot;
int a[maxn];

void adde(int u, int v, int w){
    nxt[++tot] = head[u];
    to[head[u] = tot] = v;
    val[tot] = w;
}

int f[maxn], dth[maxn], siz[maxn], wc[maxn];

void dfs1(int u, int fa){
    dth[u] = dth[fa] + 1;
    f[u] = fa;
    siz[u] = 1;
    for(int i = head[u]; i; i = nxt[i]){
        int v = to[i];
        if(v == fa) continue;
        a[v] = val[i];
        dfs1(v, u);
        siz[u] += siz[v];
        if(siz[v] > siz[wc[u]]) wc[u] = v;
    }
}

int top[maxn], dfn[maxn], rdfn[maxn], cdfn;

void dfs2(int u, int Top){
    top[u] = Top;
    dfn[u] = ++cdfn;
    rdfn[cdfn] = u;
    if(wc[u] == 0) return ;
    dfs2(wc[u], Top);
    for(int i = head[u]; i; i = nxt[i]){
        int v = to[i];
        if(v == f[u] || v == wc[u]) continue;
        dfs2(v, v);
    }
}

ll w[maxn * 4];

void pushup(int u) {w[u] = max(w[u << 1], w[u << 1 | 1]);}
void build(int u, int L, int R){
    if(L == R){
        w[u] = a[rdfn[L]];
        return; 
    }
    int mid = (L + R) >> 1;
    build(u << 1, L, mid);
    build(u << 1 | 1, mid + 1, R);
    pushup(u);
}

bool inrange(int L, int R, int l, int r){return l <= L && R <= r;}
bool outofrange(int L, int R, int l, int r){return r < L || R < l;}


int query(int u, int L, int R, int l, int r){
    if(inrange(L, R, l, r)) return w[u];
    else if(outofrange(L, R, l, r)) return 0;
    else {
        int mid = (L + R) >> 1;
        return max(query(u << 1, L, mid, l, r), query(u << 1 | 1, mid + 1, R, l, r));
    }
}

void ddxg(int u, int L, int R, int p, ll x){
    if(L == R){
        w[u] = x;
        return ;
    }
    int mid = (L + R) >> 1;
    if(p <= mid) ddxg(u << 1, L, mid, p, x);
    else ddxg(u << 1 | 1, mid + 1, R, p, x);
    pushup(u);
}

int qry(int x, int y){
    int ans = 0;
    while(top[x] != top[y]){
        if(dth[top[x]] < dth[top[y]]) swap(x, y);
        ans = max(ans, query(1, 1, n, dfn[top[x]], dfn[x]));
        x = f[top[x]];
    }
    if(x == y) return ans;
    return max(ans, query(1, 1, n, min(dfn[x], dfn[y]) + 1, max(dfn[x], dfn[y])));
    //关于+1: 绑定后 u 到 v 的路径上没有 lca(u, v)
}

int work(int x){
    x *= 2;
    int v = to[x];
    int u = to[x - 1];
    if(f[u] == v) swap(u, v);
    return dfn[v]; 
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin >> n;
    for(int i = 1; i < n; i++){
        int u, v, w;
        cin >> u >> v >> w;
        adde(u, v, w);
        adde(v, u, w);
    }

    dfs1(1, 0);
    dfs2(1, 1);
    build(1, 1, n);

    string str;
    while(cin >> str){
        if(str == "DONE") return 0;
        int x, y;
        cin >> x >> y;
        if(str == "CHANGE") ddxg(1, 1, n, work(x), y);
        else if(x == y) cout << 0 << endl;
        else cout << qry(x, y) << endl;
        
    }

	return 0;
}

P3384 【模板】重链剖分/树链剖分

如题,已知一棵包含 \(N\) 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

  • 1 x y z,表示将树从 \(x\)\(y\) 结点最短路径上所有节点的值都加上 \(z\)
  • 2 x y,表示求树从 \(x\)\(y\) 结点最短路径上所有节点的值之和。
  • 3 x z,表示将以 \(x\) 为根节点的子树内所有节点值都加上 \(z\)
  • 4 x,表示求以 \(x\) 为根节点的子树内所有节点值之和。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e5 + 10;
// const int maxm = 2e5 + 10;
int p;
int n, m, r;
int fa[maxn], siz[maxn], dep[maxn], wson[maxn], a[maxn];

vector< int > mp[maxn];

void dfs1(int u, int f){
    fa[u] = f;
    siz[u] = 1;
    dep[u] = dep[f] + 1;
    for(auto v : mp[u]){
        if(v == f) continue;
        dfs1(v, u);
        siz[u] += siz[v];
        if(siz[v] > siz[wson[u]]) wson[u] = v; // 确定重孩子
    }
}

int dfn[maxn], vistime, rdfn[maxn], top[maxn];
void dfs2(int u, int Top) {
    dfn[u] = ++vistime; // 确定dfn序,由于dfn序依赖重儿子的确定,所以必须放在dfs2   !!!
    rdfn[vistime] = u; // 反向确定DFS序第 vistime 个结点是 u
    top[u] = Top;
    if(wson[u] == 0) return ;
    dfs2(wson[u], Top);
    for(auto v : mp[u]){
        if(v == fa[u] || v == wson[u]) continue;
        dfs2(v, v);
    }
}

// 线段树部分
ll w[maxn * 4], lzy[maxn * 4];
void pushup(int u) {w[u] = (w[u << 1] + w[u << 1 | 1]) % p;}

void build(int u, int L, int R){
    if(L == R) {
        w[u] = a[rdfn[L]]; // 到达叶节点,该区间的点权是DFS序上第L个结点的权值, 即a[rdfn[L]]  !!!
        return ;
    }
    int M = (L + R) >> 1;
    build(u << 1, L, M); build(u << 1 | 1, M + 1, R);
    pushup(u);
}

bool InRange(int L, int R, int l, int r) {return l <= L && R <= r;}
bool OutofRange(int L, int R, int l, int r) {return r < L || R < l;}

void maketag(int u, int len, ll x){
    lzy[u] += x;
    w[u] += x * len % p;
    lzy[u] %= p;
}

void pushdown(int u, int L, int R){
    if(lzy[u] == 0) return; // 如果没有懒标记,直接返回
    int M = (L + R) >> 1;
    maketag(u << 1, M - L + 1, lzy[u]);
    maketag(u << 1 | 1, R - M, lzy[u]);
    lzy[u] = 0;
}

ll qjcx(int u, int L, int R, int l, int r){
    if(InRange(L, R, l, r)) return w[u];
    if(OutofRange(L, R, l, r)) return 0;
    pushdown(u, L, R);
    int M = (L + R) >> 1;
    return (qjcx(u << 1, L, M, l, r) + qjcx(u << 1 | 1, M + 1, R, l, r)) % p;
}

void qjxg(int u, int L, int R, int l, int r, ll x){
    if(InRange(L, R, l, r)) {
        maketag(u, R - L + 1, x);
        return;
    }
    if(OutofRange(L, R, l, r)) return;
    pushdown(u, L, R);
    int M = (L + R) >> 1;
    qjxg(u << 1, L, M, l, r, x);
    qjxg(u << 1 | 1, M + 1, R, l, r, x);
    pushup(u);
}

void upd(int x, int y, ll z){
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        qjxg(1, 1, n, dfn[top[x]], dfn[x], z); // 把跳链这段区间先给修改了
        x = fa[top[x]];
    }
    qjxg(1, 1, n, min(dfn[x], dfn[y]), max(dfn[x], dfn[y]), z); // 最后在同一个链上
}

ll qry(int x, int y){
    ll res = 0;
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        res += qjcx(1, 1, n, dfn[top[x]], dfn[x]); // 注意较深的点 DFS 序大,整条链是从上到下铺展
        x = fa[top[x]];
    }
    res += qjcx(1, 1, n, min(dfn[x], dfn[y]), max(dfn[x], dfn[y]));
    return res % p;
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin >> n >> m >> r >> p;
    for(int i = 1; i <= n; i++) cin >> a[i];
    for(int i = 1; i < n; i++){
        int u, v;
        cin >> u >> v;
        mp[u].push_back(v);
        mp[v].push_back(u);
    }
    dfs1(r, 0);
    dfs2(r, r);
    build(1, 1, n);
    for(int op, x, y, z; m; --m){
        cin >> op;
        if(op == 1){
            cin >> x >> y >> z;
            upd(x, y, z); // 表示将树从 x 到 y 结点最短路径上所有节点的值都加上 z
        }
        else if(op == 2){
            cin >> x >> y;
            cout << qry(x, y) << "\n"; // 表示求树从 x 到 y 结点最短路径上所有节点的值之和。
        }
        else if(op == 3){
            cin >> x >> z;
            qjxg(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, z); // 表示将以 x 为根节点的子树内所有节点值都加上 z。
            // 这里很巧妙的利用 DFS 序,一个结点子树内的 DFS 序是连续的
        }
        else {
            cin >> x;
            cout << qjcx(1, 1, n, dfn[x], dfn[x] + siz[x] - 1) % p << "\n"; // 表示求以 x 为根节点的子树内所有节点的值之和。
        }

    }

	return 0;
}

树上二分

P2680 [NOIP 2015 提高组] 运输计划

最小化最大给定路径点权和

点击查看代码

#include<bits/stdc++.h>
using namespace std;
const int maxn = 3e5 + 10;
const int maxm = 6e5 + 10;

int n, m;
int head[maxn], to[maxm], nxt[maxm], eval[maxm], tot;
int fa[maxn], dep[maxn], siz[maxn], son[maxn], top[maxn];
int val[maxn], dis[maxn]; // 关键修改1: 增加dis数组存储根到节点的距离
int uu[maxn], vv[maxn], lcaa[maxn], path_len[maxn]; // 关键修改2: 存储所有路径信息
int diff[maxn], rak[maxn], dfn[maxn], cnt; // 关键修改3: 增加树上差分相关数组

void adde(int u, int v, int w) {
    nxt[++tot] = head[u];
    to[head[u] = tot] = v;
    eval[tot] = w;
}

// 关键修改4: 重构DFS1,添加距离计算
void dfs1(int u, int f, int d) {
    fa[u] = f; dep[u] = d; siz[u] = 1;
    son[u] = 0;
    for (int i = head[u]; i; i = nxt[i]) {
        int v = to[i];
        if (v == f) continue;
        val[v] = eval[i]; // 边权赋给子节点
        dis[v] = dis[u] + val[v]; // 计算从根到v的距离
        dfs1(v, u, d + 1);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) son[u] = v;
    }
}

// 关键修改5: 重构DFS2,添加DFS序记录
void dfs2(int u, int tp) {
    top[u] = tp;
    dfn[u] = ++cnt;   // DFS序
    rak[cnt] = u;     // DFS序对应的节点
    if (!son[u]) return;
    dfs2(son[u], tp);
    for (int i = head[u]; i; i = nxt[i]) {
        int v = to[i];
        if (v == fa[u] || v == son[u]) continue;
        dfs2(v, v);
    }
}

// 关键修改6: 修复LCA函数
int LCA(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
    return dep[x] < dep[y] ? x : y;
}

// 关键修改7: 实现二分答案检查函数
bool check(int mid) {
    memset(diff, 0, sizeof(diff)); // 初始化差分数组
    int cnt_path = 0, max_len = 0; // cnt_path: 超过mid的路径数,max_len: 这些路径中的最大长度
    
    // 标记所有超过mid的路径
    for (int i = 1; i <= m; i++) {
        if (path_len[i] > mid) {
            cnt_path++;
            max_len = max(max_len, path_len[i]); // 更新最大长度
            diff[uu[i]]++;    // 路径起点+1
            diff[vv[i]]++;    // 路径终点+1
            diff[lcaa[i]] -= 2; // LCA处-2
        }
    }
    
    // 如果没有超过mid的路径,直接返回true
    if (!cnt_path) return true;
    
    // 关键修改8: 通过DFS序逆序累加差分值(从叶子到根)
    for (int i = n; i >= 1; i--) {
        int u = rak[i]; // 获取DFS序为i的节点
        diff[fa[u]] += diff[u]; // 向父节点累加
    }
    
    // 关键修改9: 检查是否存在满足条件的边
    for (int i = 2; i <= n; i++) { // 从2开始,根节点没有对应的边
        // 条件1: 该边被所有超过mid的路径覆盖
        // 条件2: 该边权值足够大,使得删除后最长路径<=mid
        if (diff[i] == cnt_path && val[i] >= max_len - mid)
            return true;
    }
    return false;
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i < n; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        adde(u, v, w);
        adde(v, u, w);
    }
    
    // 关键修改10: 执行两次DFS进行树链剖分
    dfs1(1, 0, 1);
    dfs2(1, 1);
    
    int max_len = 0; // 存储所有路径中的最大长度
    // 预处理所有路径信息
    for (int i = 1; i <= m; i++) {
        scanf("%d%d", &uu[i], &vv[i]);
        lcaa[i] = LCA(uu[i], vv[i]); // 计算LCA
        // 计算路径长度: dis[u] + dis[v] - 2*dis[lca]
        path_len[i] = dis[uu[i]] + dis[vv[i]] - 2 * dis[lcaa[i]];
        max_len = max(max_len, path_len[i]); // 更新最大路径长度
    }
    
    // 关键修改11: 二分答案框架
    int l = 0, r = max_len; // 答案范围[0, 最长路径长度]
    while (l < r) {
        int mid = (l + r) >> 1;
        if (check(mid)) r = mid; // 满足条件,尝试更小的答案
        else l = mid + 1;        // 不满足条件,需要增大答案
    }
    printf("%d\n", l);
    return 0;
}

树上路径

P2486 [SDOI2011] 染色

区间赋值、求颜色连续段数量

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e5 + 10;
// const int maxm = 2e5 + 10;

int n, m, r;
int fa[maxn], siz[maxn], dep[maxn], wson[maxn], a[maxn];

vector< int > mp[maxn];

void dfs1(int u, int f){
    fa[u] = f;
    siz[u] = 1;
    dep[u] = dep[f] + 1;
    for(auto v : mp[u]){
        if(v == f) continue;
        dfs1(v, u);
        siz[u] += siz[v];
        if(siz[v] > siz[wson[u]]) wson[u] = v; // 确定重孩子
    }
}

int dfn[maxn], vistime, rdfn[maxn], top[maxn];
void dfs2(int u, int Top) {
    dfn[u] = ++vistime; // 确定dfn序,由于dfn序依赖重儿子的确定,所以必须放在dfs2   !!!
    rdfn[vistime] = u; // 反向确定DFS序第 vistime 个结点是 u
    top[u] = Top;
    if(wson[u] == 0) return ;
    dfs2(wson[u], Top);
    for(auto v : mp[u]){
        if(v == fa[u] || v == wson[u]) continue;
        dfs2(v, v);
    }
}

// 线段树部分
struct Node {
    int lcol, rcol;
    ll seg;
    int lzy;
    Node(ll seg = 0, int lzy = 0) : seg(seg), lzy(lzy) {}
} t[maxn * 4];

void pushup(int u) {
    t[u].lcol = t[u << 1].lcol;
    t[u].rcol = t[u << 1 | 1].rcol;
    t[u].seg = t[u << 1].seg + t[u << 1 | 1].seg - (t[u << 1].rcol == t[u << 1 | 1].lcol);
}


void build(int u, int L, int R){
    t[u].lzy = 0;
    if (L == R) {
        t[u].lcol = t[u].rcol = a[rdfn[L]];
        t[u].seg = 1;
        return;
    }
    int M = (L + R) >> 1;
    build(u << 1, L, M); build(u << 1 | 1, M + 1, R);
    pushup(u);
}

bool InRange(int L, int R, int l, int r) {return l <= L && R <= r;}
bool OutofRange(int L, int R, int l, int r) {return r < L || R < l;}

void maketag(int u, int L, int R, int c) {
    t[u].lcol = t[u].rcol = c;
    t[u].seg = 1;
    t[u].lzy = c;
}

void pushdown(int u, int L, int R) {
    if (!t[u].lzy) return;
    int M = (L + R) >> 1;
    maketag(u << 1, L, M, t[u].lzy);
    maketag(u << 1 | 1, M + 1, R, t[u].lzy);
    t[u].lzy = 0;
}

struct Res {
    int lcol, rcol;
    ll seg;
    Res(int l=0, int r=0, ll s=0): lcol(l), rcol(r), seg(s) {}
};

Res qjcx(int u, int L, int R, int l, int r) {
    if (l > R || r < L) return Res(-1, -1, 0);// 不相交,返回空段
    if (l <= L && R <= r) return Res(t[u].lcol, t[u].rcol, t[u].seg);
    pushdown(u, L, R);
    int M = (L + R) >> 1;
    Res left = qjcx(u << 1, L, M, l, r);
    Res right = qjcx(u << 1 | 1, M + 1, R, l, r);
    // 注意这里,显然不会存在一个段seg=1,所以这里seg=0是标记的空段,并不包含在目标区间内,所以要特判避免干扰我们的左右端点颜色
    if (left.seg == 0) return right;
    if (right.seg == 0) return left;
    ll seg = left.seg + right.seg - (left.rcol == right.lcol);
    return Res(left.lcol, right.rcol, seg);
}

void qjxg(int u, int L, int R, int l, int r, int c) {
    if (l > R || r < L) return;
    if (l <= L && R <= r) {
        maketag(u, L, R, c);
        return;
    }
    pushdown(u, L, R);
    int M = (L + R) >> 1;
    qjxg(u << 1, L, M, l, r, c);
    qjxg(u << 1 | 1, M + 1, R, l, r, c);
    pushup(u);
}

void upd(int x, int y, ll z){
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        qjxg(1, 1, n, dfn[top[x]], dfn[x], z); // 把跳链这段区间先给修改了
        x = fa[top[x]];
    }
    qjxg(1, 1, n, min(dfn[x], dfn[y]), max(dfn[x], dfn[y]), z); // 最后在同一个链上
}

// 路径查询
Res qry(int x, int y) {
    vector<Res> lft, rht;
    while (top[x] != top[y]) {
        if (dep[top[x]] >= dep[top[y]]) {
            lft.push_back(qjcx(1, 1, n, dfn[top[x]], dfn[x]));
            x = fa[top[x]];
        } else {
            rht.push_back(qjcx(1, 1, n, dfn[top[y]], dfn[y]));
            y = fa[top[y]];
        }
    }
    int l = dfn[x], r = dfn[y];
    Res mid;
    if (l <= r) {
        mid = qjcx(1, 1, n, l, r);
        // mid.lcol 连接 x,mid.rcol 连接 y
    } else {
        mid = qjcx(1, 1, n, r, l);
        swap(mid.lcol, mid.rcol); // 交换后,mid.lcol 是 x,mid.rcol 是 y
    }

    // 合并顺序:rht(逆序)+ mid + lft(正序)
    Res res = mid;
    for (int i = rht.size() - 1; i >= 0; --i) {
        res.seg += rht[i].seg - (rht[i].lcol == res.rcol);
        res.rcol = rht[i].rcol;
    }
    for (int i = 0; i < lft.size(); ++i) {
        if(i != lft.size() - 1) res.seg += lft[i].seg - (lft[i + 1].rcol == lft[i].lcol);
        else res.seg += lft[i].seg - (lft[i].lcol == res.lcol);
    }
    return res;
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin >> n >> m;
    r = 1;
    for(int i = 1; i <= n; i++) cin >> a[i];
    for(int i = 1; i < n; i++){
        int u, v;
        cin >> u >> v;
        mp[u].push_back(v);
        mp[v].push_back(u);
    }
    dfs1(r, 0);
    dfs2(r, 0);
    build(1, 1, n);
    for(int x, y, z; m; --m){
        char ch;
        cin >> ch;
        if(ch == 'C'){
            cin >> x >> y >> z;
            upd(x, y, z); // 表示将树从 x 到 y 结点最短路径上所有节点的值都染色为 z
        }
        else if(ch == 'Q'){
            cin >> x >> y;
            cout << qry(x, y).seg << "\n"; // 表示求树从 x 到 y 结点最短路径上颜色段数量
        }
       

    }

	return 0;
}

树上赋值 区间求和 区间最值

注意,lzy[u] 初始化不能为0的时候,要在build里把每一个u都初始化

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e5 + 10;
const int inf = 2e5 + 10;
int n, m, r;
int fa[maxn], siz[maxn], dep[maxn], wson[maxn], a[maxn];

vector< int > mp[maxn];

void dfs1(int u, int f){
    fa[u] = f;
    siz[u] = 1;
    dep[u] = dep[f] + 1;
    for(auto v : mp[u]){
        if(v == f) continue;
        dfs1(v, u);
        siz[u] += siz[v];
        if(siz[v] > siz[wson[u]]) wson[u] = v; // 确定重孩子
    }
}

int dfn[maxn], vistime, rdfn[maxn], top[maxn];
void dfs2(int u, int Top) {
    dfn[u] = ++vistime; // 确定dfn序,由于dfn序依赖重儿子的确定,所以必须放在dfs2   !!!
    rdfn[vistime] = u; // 反向确定DFS序第 vistime 个结点是 u
    top[u] = Top;
    if(wson[u] == 0) return ;
    dfs2(wson[u], Top);
    for(auto v : mp[u]){
        if(v == fa[u] || v == wson[u]) continue;
        dfs2(v, v);
    }
}

// 线段树部分
ll w[maxn * 4], lzy[maxn * 4], maxx[maxn * 4];
void pushup(int u) {
    w[u] = (w[u << 1] + w[u << 1 | 1]);
    maxx[u] = max(maxx[u << 1], maxx[u << 1 | 1]);
}

void build(int u, int L, int R){
    lzy[u] = inf; // 未正确初始化导致lzy[u] = 0 被当作赋值标记全给赋为0了 !!!!!
    if(L == R) {
        w[u] = a[rdfn[L]]; // 到达叶节点,该区间的点权是DFS序上第L个结点的权值, 即a[rdfn[L]]  !!!
        maxx[u] = a[rdfn[L]];
        lzy[u] = inf;
        return ;
    }
    int M = (L + R) >> 1;
    build(u << 1, L, M); build(u << 1 | 1, M + 1, R);
    pushup(u);
}

bool InRange(int L, int R, int l, int r) {return l <= L && R <= r;}
bool OutofRange(int L, int R, int l, int r) {return r < L || R < l;}

void maketag(int u, int len, ll x){
    lzy[u] = x;
    w[u] = x * len;
    maxx[u] = x; 
}

void pushdown(int u, int L, int R){
    if(lzy[u] == inf) return; // 如果没有懒标记,直接返回
    int M = (L + R) >> 1;
    maketag(u << 1, M - L + 1, lzy[u]);
    maketag(u << 1 | 1, R - M, lzy[u]);
    lzy[u] = inf;
}

ll qjcx(int u, int L, int R, int l, int r){
    if(InRange(L, R, l, r)) return w[u];
    if(OutofRange(L, R, l, r)) return 0;
    pushdown(u, L, R);
    int M = (L + R) >> 1;
    return (qjcx(u << 1, L, M, l, r) + qjcx(u << 1 | 1, M + 1, R, l, r));
}
ll qjcx_max(int u, int L, int R, int l, int r){
    if(InRange(L, R, l, r)) return maxx[u];
    if(OutofRange(L, R, l, r)) return -inf;
    pushdown(u, L, R);
    int M = (L + R) >> 1;
    return max(qjcx_max(u << 1, L, M, l, r) , qjcx_max(u << 1 | 1, M + 1, R, l, r));
}
void ddxg(int u, int L, int R, int pos, ll x) {
    if(L == R) {
        w[u] = x;
        maxx[u] = x;
        return;
    }
    int M = (L + R) >> 1;
    if(pos <= M) ddxg(u << 1, L, M, pos, x);
    else ddxg(u << 1 | 1, M + 1, R, pos, x);
    pushup(u);
}

void qjxg(int u, int L, int R, int l, int r, ll x){
    if(InRange(L, R, l, r)) {
        maketag(u, R - L + 1, x);
        return;
    }
    if(OutofRange(L, R, l, r)) return;
    pushdown(u, L, R);
    int M = (L + R) >> 1;
    qjxg(u << 1, L, M, l, r, x);
    qjxg(u << 1 | 1, M + 1, R, l, r, x);
    pushup(u);
}

void upd(int x, int y, ll z){
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        qjxg(1, 1, n, dfn[top[x]], dfn[x], z); // 把跳链这段区间先给修改了
        x = fa[top[x]];
    }
    qjxg(1, 1, n, min(dfn[x], dfn[y]), max(dfn[x], dfn[y]), z); // 最后在同一个链上
}

ll qry_sum(int x, int y){
    ll res = 0;
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        res += qjcx(1, 1, n, dfn[top[x]], dfn[x]); // 注意较深的点 DFS 序大,整条链是从上到下铺展
        x = fa[top[x]];
    }
    res += qjcx(1, 1, n, min(dfn[x], dfn[y]), max(dfn[x], dfn[y]));
    return res;
}
ll qry_max(int x, int y){
    ll res = -inf;
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        res = max(res, qjcx_max(1, 1, n, dfn[top[x]], dfn[x])); // 注意较深的点 DFS 序大,整条链是从上到下铺展
        x = fa[top[x]];
    }
    res = max(res, qjcx_max(1, 1, n, min(dfn[x], dfn[y]), max(dfn[x], dfn[y])));
    return res;
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin >> n;
    r = 1;

    for(int i = 1; i < n; i++){
        int u, v;
        cin >> u >> v;
        mp[u].push_back(v);
        mp[v].push_back(u);
    }

    for(int i = 1; i <= n; i++) cin >> a[i];

    dfs1(r, 0);
    dfs2(r, r);
    build(1, 1, n);
    int q;
    cin >> q;

    for(int x, y, z; q; --q){
        string op;
        cin >> op;
        if(op == "CHANGE"){
            cin >> x >> z;
            ddxg(1, 1, n, dfn[x], z);
        }
        else if(op == "QMAX"){
            cin >> x >> y;
            cout << qry_max(x, y) << "\n";
        }
        else if(op == "QSUM"){
            cin >> x >> y;
            cout << qry_sum(x, y) << endl;

        }
    }

	return 0;
}


posted @ 2025-08-22 11:48  [丘李]Chilllee  阅读(5)  评论(0)    收藏  举报