树分治总结

Empty

本文同步发表在 博客园

静态树分治

这个算法可以处理关于树上所有链信息的问题,由于枚举所有链的复杂度绝对是 \(O(n^2)\)。所以需要更优秀的算法,可以通过合并两条链的方式枚举所有链。

树分治是一个暴力数据结构。

它可以在 \(O(n \log n)\) 的时间里遍历所有的链。
下面的算法可以处理树形态不变的情况下的问题。

边分治

找一条边,把树分成两半。在两边子树找两点,则有一条过这条边的路径。

再对两个子树分治。

可是菊花图会被卡成 \(n^ 2\)

使用 “左二子,右兄弟” 的方法可以优化。

可是也有局限性(我并不是太了解边分治)。

点分治。

相比于边分治,点分治更通用一些,后面的点分树也是由此转化得来。

我们每次找到一个点 \(u\),处理经过 \(u\) 的点的答案,处理子树之间的路径。

![[1.png]]

删去 \(u\)

再对 \(u\) 的子树分治。

这样可以 “遍历” 所有路径。

每次分治的中心选取 重心 是最优的(每次子树大小除以 \(2\))。

于是时间复杂度是 \(O(n \log n)\) 的。

所以我们可以每次遍历所有分治中心的子树,暴力计算答案

一般有两种方式计算答案。

  1. 容斥。用所有点两两的答案减去子树内的答案。

  2. 数据结构维护。每次用遍历当前子树,用数据结构查询之前子树的答案,在把当前字树加入数据结构。

代码实例:

int sze[400010], dp[400010], vis[400010]; // vis:是否当做过分治中心
int tot, rt;
void getrt(int u, int fa) {
    sze[u] = 1, dp[u] = 0;
    for (auto y : edge[u]) {
        int v = y.first;
        if (vis[v] || v == fa)
            continue;
        getrt(v, u);
        sze[u] += sze[v];
        dp[u] = max(dp[u], sze[v]);
    }
    dp[u] = max(dp[u], sum - sze[u]);
    if (dp[u] < dp[rt])
        rt = u;
}

void solve(int u) {
    vis[u] = 1;
    ans += getans(u); // 以容斥计算答案为例
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v])
            continue;
        ans -= getans(v); // 
        sum = sze[v];
        dp[0] = n, rt = 0;
        getrt(v, u);
        solve(rt);
    }
}

好了,至此我们已经学会了点分治,下面要开始实战了。

先来一道最基础的题:

例1.1 Tree

给定大小为 \(n\)带权 树。

求长度小于等于 \(k\) 的路径条数。

\(n \le 10^5\)

Sol:

设当前的分治中心为 \(u\)
答案要加上:\(\sum_{subtree(x) \neq subtree(y)} [dis(x) + dis(y) \le k]\)

考虑把和式 \(x, y\) 的限制去掉。

它等于所有的 \(x, y\),减去 \(x, y\) 在同一子树的答案。

\(\sum_{x, y}[dis(x) + dis(y) \le k] - \sum_{subtree(x) = subtree(y)} [dis(x) + dis(y) \le k]\)

可以用容斥算。

对于 \(\sum_{x, y}[dis(x) + dis(y) \le k]\) 这样的式子。可以用双指针维护。

看代码吧:

void getrt(int u, int fa) {
    sze[u] = 1, dp[u] = 0;
    for (auto y : edge[u]) {
        int v = y.first;
        if (vis[v] || v == fa)
            continue;
        getrt(v, u);
        sze[u] += sze[v];
        dp[u] = max(dp[u], sze[v]);
    }
    dp[u] = max(dp[u], sum - sze[u]);
    if (dp[u] < dp[rt])
        rt = u;
}
void getdis(int u, int fa) {
    rev[++tot] = d[u];
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (!vis[v] && v != fa) {
            d[v] = d[u] + w;
            getdis(v, u);
        }
    }
}// 算距离
int getans(int u, int w) {
    tot = 0, d[u] = w;
    getdis(u, 0);
    sort(rev + 1, rev + tot + 1);
    int l = 1, r = tot, tmp = 0;
    while (l <= r) {
        if (rev[l] + rev[r] <= k)
            tmp += r - l, ++l;
        else
            r--;
    }
    return tmp;
}// 得到 rev 数组里两两的答案
void solve(int u) {
    vis[u] = 1;
    ans += getans(u, 0);// 加上所有的点的答案。
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v])
            continue;
        ans -= getans(v, w); // 减去子树内的答案,得到两两子树的答案
        sum = sze[v];
        dp[0] = n, rt = 0;
        getrt(v, u);
        solve(rt);
    }
}

练1.1.1 聪聪可可

P2634 [国家集训队] 聪聪可可

题意:给你一颗 带权 树,求有多少条路径满足长度是 \(3\) 的倍数。

sol:这算是刚才那道题改了一下(更简单了)。

还是利用容斥。

一样的,答案为:\(\sum_{x, y}[3 | dis(x) + dis(y) ] - \sum_{subtree(x) = subtree(y)} [3 | dis(x) + dis(y)]\)

容易计算。

代码

#include <bits/stdc++.h>
using namespace std;
const int N = 20010;
#define int long long
vector<pair<int, int> > edge[N];
int n;
int sze[N], dp[N];
int vis[N];
int d[N];
int tot, rt, sum;
int rev[N];
int ans;
void getrt(int u, int fa) {
    sze[u] = 1, dp[u] = 0;
    for (auto y : edge[u]) {
        int v = y.first;
        if (vis[v] || v == fa)
            continue;
        getrt(v, u);
        sze[u] += sze[v];
        dp[u] = max(dp[u], sze[v]);
    }
    dp[u] = max(dp[u], sum - sze[u]);
    if (dp[u] < dp[rt])
        rt = u;
}
void getdis(int u, int fa) {
    rev[++tot] = d[u];
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (!vis[v] && v != fa) {
            d[v] = d[u] + w;
            getdis(v, u);
        }
    }
}
int getans(int u, int w) {
    tot = 0, d[u] = w;
    getdis(u, 0);
    int s[3] = { 0, 0, 0 };
    for (int i = 1; i <= tot; i++) {
        s[rev[i] % 3]++;
    }

    return s[0] * s[0] + s[1] * s[2] * 2;
}
void solve(int u) {
    vis[u] = 1;
    ans += getans(u, 0);
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v])
            continue;
        ans -= getans(v, w);
        sum = sze[v];
        dp[0] = n, rt = 0;
        getrt(v, u);
        solve(rt);
    }
}

int gcd(int a, int b) { return (b == 0) ? a : gcd(b, a % b); }

signed main() {
    cin >> n;
    for (int i = 1; i < n; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        edge[u].push_back({ v, w });
        edge[v].push_back({ u, w });
    }
    dp[0] = sum = n;

    getrt(1, 0);
    solve(rt);
    cout << ans / gcd(ans, n * n) << "/" << n * n / gcd(ans, n * n);
    return 0;
}

练1.1.2 重建计划

(P4292 [WC2010] 重建计划)[https://www.luogu.com.cn/problem/P4292]

题意:

给你一棵带权树:求一条边数在 \([L, R]\) 路径,使权值和除以边数最大,

设边权为 \(w_1, w_2, \cdots, w_k\)

答案为:$$\frac{w_1 + w_2 + \cdots w_k}{k}$$。

二分答案 \(x\),要使得:

\[\frac{w_1 + w_2 + \cdots w_k}{k} \ge x \]

即 $${w_1 + w_2 + \cdots w_k} \ge kx$$

即 $${(w_1 - x) + (w_2 - x) + (w_3 - x) \cdots (w_k - x)} \ge 0$$

每次二分先把边权减去 \(x\)

我们要求最长的一条路径的长度 \(\ge 0\)

还是点分治。每次用线段树维护边数在一段区间的答案,可是这样做的时间复杂度为 \(O(n \log^3 n)\)

注意到每次查询的区间为 \([L - dep, R - dep]\)

按照子树内 \(dep\) 从大到小排序的话,可以用单调队列维护,查询完答案后与原数组的值取 \(\max\)

这样做还会有问题,单调队列的大小由最深的子树决定。最坏复杂度为 \(n^2 \log n\)

所以我们得按照子树深度从小到大的顺序枚举。

时间复杂度 \(O(n \log^2 n)\)

代码:
这里放的是同学的代码,因为我是用 \(n \log^3 n\) 的算法卡过去的。
见谅。


#include <bits/stdc++.h>
using namespace std;

typedef int ll;
typedef double ld;

const ll Pig = 2e5 + 10;

ll n, L, R, dis[Pig], cnt[Pig], ans, sze[Pig], p[Pig], cur, ln[Pig], len;
ld d[Pig], buc[Pig], val;
vector<pair<ll, ll> > g[Pig];
vector<ll> pnt;
bitset<Pig> vis;

void dfs_ln(ll i, ll f, ll t) {
    dis[i] = dis[f] + 1;
    ln[t] = max(ln[t], dis[i]);

    for (auto j : g[i]) {
        if (j.first == f or vis[j.first])
            continue;

        dfs_ln(j.first, i, t);
    }
}

void dfs1(ll i, ll f) {
    sze[i] = 1;
    p[i] = 0;

    for (auto j : g[i]) {
        if (j.first == f or vis[j.first])
            continue;

        ll v = j.first, w = j.second;
        d[v] = d[i] + w - val;
        dis[v] = dis[i] + 1;
        dfs1(v, i);
        sze[i] += sze[v];
        p[i] = max(p[i], sze[v]);
    }
}

void dfs2(ll i) {
    if (ans)
        return;

    vector<ll> v, curr;
    vector<pair<ll, ll> > gg;
    buc[0] = 0;
    dis[i] = 1;
    vis[i] = 1;
    len = 0;

    for (auto j : g[i]) {
        if (!vis[j.first]) {
            ln[j.first] = 0;
            dfs_ln(j.first, i, j.first);
            gg.emplace_back(j);
        }
    }

    sort(gg.begin(), gg.end(), [&](pair<ll, ll> a, pair<ll, ll> b) { return ln[a.first] < ln[b.first]; });

    for (auto j : gg) {
        cur = j.first;
        d[cur] = j.second - val;
        dis[j.first] = 1;
        dfs1(cur, i);
        queue<ll> q;
        deque<ll> c;
        vector<ll> point;
        q.emplace(cur);
        ll r = -1;

        while (!q.empty()) {
            ll pt = q.front();
            p[pt] = max(p[pt], sze[j.first] - sze[pt]);

            if (p[pt] < p[cur])
                cur = pt;

            q.pop();
            point.emplace_back(pt);
            for (auto k : g[pt])
                if (dis[k.first] > dis[pt])
                    q.emplace(k.first);
        }

        reverse(point.begin(), point.end());

        for (ll k : point) {
            while (r < R - dis[k] and r < len) {
                r++;
                while (!c.empty() and buc[c.back()] < buc[r]) c.pop_back();
                c.emplace_back(r);
            }
            while (!c.empty() and c.front() + dis[k] < L) c.pop_front();
            if (!c.empty() and buc[c.front()] + d[k] >= 0)
                ans = 1;
        }

        for (ll k : point) buc[dis[k]] = max(buc[dis[k]], d[k]), curr.emplace_back(k), len = max(len, dis[k]);

        v.emplace_back(cur);

        if (ans) {
            vis[i] = 0;
            buc[0] = buc[Pig - 1];

            for (ll k : curr) buc[dis[k]] = buc[Pig - 1];

            return;
        }
    }

    buc[0] = buc[Pig - 1];

    for (ll k : curr) buc[dis[k]] = buc[Pig - 1];

    if (ans) {
        vis[i] = 0;
        return;
    }

    for (ll j : v) dfs2(j);

    vis[i] = 0;
}

ll read() {
    char c = getchar();
    ll res = 0;
    bool flg = 1;

    while (!isdigit(c)) {
        if (c == '-')
            flg = 0;
        c = getchar();
    }

    while (isdigit(c)) res = (res << 3) + (res << 1) + (c ^ '0'), c = getchar();

    if (!flg)
        res = -res;
    return res;
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.setf(ios::fixed);
    cout.precision(3);
    memset(buc, -0x3f, sizeof(buc));
    n = read();
    L = read();
    R = read();
    bool flg = 1;

    for (ll i = 1, u, v, w; i < n; i++) {
        u = read();
        v = read();
        w = read();
        g[u].emplace_back(v, w);
        g[v].emplace_back(u, w);
    }

    ld l = 0, r = 1e6;

    while (fabs(l - r) > 4.5e-4) {
        val = (l + r) / 2;
        ans = 0;
        dfs2(1);

        if (ans)
            l = val;
        else
            r = val;
    }

    cout << (l + r) / 2;
    return 0;
}

练1.1.3 Yin and Yang G

P3085 [USACO13OPEN] Yin and Yang G

题意:给你一棵边权为 \(1\) 或者 \(-1\) 的树。求有多少条路径 \(u \to v\)\(u \ne v\)),满足路径上存在一点 \(p\)\(p\neq n\)\(p \neq m\))使得 $\mathrm{dis(u, p)} = \mathrm{dis(p, m)} = 0 $

sol
等价于求 \(\mathrm{dis(u, v)} = \mathrm{dis(u, p) = 0}\)

设分治中心为 \(r\),子树内的点到它的距离记为 \(d_i\)

假设在它的子树中有两点 \(u, v\),使得 \(d_u + d_v = 0\),那是不是在 \(u\) 或者 \(v\) 的祖先中存在一个点 \(p\) 使得 \(dis_u = dis_p\) 或者 \(dis_v = dis_p\) 才行。即选的点 \(p\)\(dis\) 要和 \(u\) 或者 \(v\) 的值相等。

我们按照一个点 \(u\) 的祖先有没有点和 \(d_u\) 相等把点分成两类。

没有的点只能与有的点配成一对,而有的点可以和两种点配对。

这个用数组统计。需要注意的是要仔细考虑 \(u\)\(r\) 配对的情况。

代码:
为了避免负数,我将下标平移了一段区间,也可以使用 map 解决。

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 2e5 + 10;
int n;
vector<pair<int, int> > edge[N];
int dp[N], sze[N], vis[N], cnt[N], res[N], cnt1[N];
int rt, sum, ans;

void getrt(int u, int f) {
    dp[u] = 0, sze[u] = 1;
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v] || v == f)
            continue;
        getrt(v, u);
        sze[u] += sze[v];
        dp[u] = max(dp[u], sze[v]);
    }
    dp[u] = max(dp[u], sum - sze[u]);
    if (dp[u] < dp[rt])
        rt = u;
}

int mn, mx;
bool flag = 0;
void getans(int u, int f, int dis) {
    ans += (res[dis + 100000] > 0) * cnt[-dis + 100000] + cnt1[-dis + 100000];
    if (dis == 0 && res[dis + 100000] > 1)
        ans++;
    res[dis + 100000]++;//统计当前点到 $r$ 的 $d$
    mn = min(mn, dis);
    mx = max(mx, dis);
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v] || v == f)
            continue;
        getans(v, u, dis + w);
    }
    res[dis + 100000]--;
}

void getdis(int u, int f, int dis) {
    cnt[dis + 100000] += (res[dis + 100000] == 0);
    cnt1[dis + 100000] += (res[dis + 100000] > 0);
    res[dis + 100000]++;//统计当前点到 $r$ 的 $d$

    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v] || v == f)
            continue;
        getdis(v, u, dis + w);
    }
    res[dis + 100000]--;
}

void solve(int u) {
    vis[u] = 1;
    res[100000] = 1;
    mn = 1e5, mx = -1e5;
    flag = 0;
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v])
            continue;
        getans(v, u, w);
        getdis(v, u, w);
    }
    for (int i = mn; i <= mx; i++) res[i + 100000] = cnt[i + 100000] = cnt1[i + 100000] = 0;
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v])
            continue;

        sum = sze[v];
        rt = 0, dp[rt] = n + 1;
        getrt(v, u);
        solve(rt);
    }
}

signed main() {
    cin >> n;
    for (int i = 1; i < n; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        if (w == 0)
            w = -1;
        edge[u].push_back({ v, w });
        edge[v].push_back({ u, w });
    }
    sum = n + 1;
    rt = 0, dp[rt] = n + 1;
    getrt(1, 0);
    solve(rt);
    cout << ans;
    return 0;
}

广义点分治

传统的点分治帮助我们快速统计树上所有路径。

可以发现分治的层数 不超过 \(\log n\) 层。

于是一类问题出现了,问一棵树上的最优点 \(x\)

怎么刻画最优是题目规定的,如带权重心(P3345 幻想乡战略游戏)。

我们先假设 \(f(x)\) 为这个点的答案。

如果这样的点只有一个,且满足单调性:若这个点的答案不优,则它的子树的答案一定更不优。

所以答案一定在满足 \(f(x) < f(fa_x)\) 的子树里。

我们可以从 \(fa_x\) 跳到 \(x\) 子树的重心。

则最优点一定在点分树的子子树里。

例1.2 快递员

题意:
(不想改题面啦qwq)

Showson 的城市里面有 \(n\) 家快递站,被 \(n - 1\) 条带权无向边相连。

Showson 需要送 \(m\) 个快递,第 \(i\) 个货物需要从 \(u\) 送到 \(v\)。由于 Showson 不能带着货物走太长的路,所以对于一次送货,他需要先从集散中心到 \(u\),再从 \(u\) 回到集散中心,再从集散中心到 \(v\),最后从 \(v\) 返回集散中心。换句话说,如果设集散中心开在 \(c\) 号点,那么他的路径是 \(c \rightarrow u \rightarrow c \rightarrow v \rightarrow c\)

现在 Showson 希望确定一个点作为集散中心的开设位置,使得他送货所需的最长距离最小。显然,这个最长距离是个偶数,你只需要输出最长距离除以 \(2\) 的结果即可。

sol

点分治。

设现在的快递中心为 \(r\)

先暴力求出快递中心到所有点对的距离和。

找到所有使答案最大的点对。

\(r\) 在它们的路径上,则答案无法再小。

若两组点对所在的子树不同,答案也无法再小。

假如可以再小。

就往那个子树分治。

#include<bits/stdc++.h>
using namespace std;
//#define int long long

const int N = 1e5 + 10;
int n, m; 
vector<pair<int, int> > edge[N];

int vis[N], dp[N], sze[N], tot, rt;

struct node{
	int x, y;
}a[N];

void getrt(int u, int f) {
	dp[u] = 0, sze[u] = 1;
	for(auto y : edge[u]) {
		int v = y.first, w = y.second;
		if(vis[v] || v == f) continue;
		getrt(v, u);
		dp[u] = max(dp[u], sze[v]);
		sze[u] += sze[v];
	}
	dp[u] = max(dp[u], tot - sze[u]);
	if(dp[u] < dp[rt]) rt = u;
}
int d[N], sub[N], ans = 1e9;
void getdis(int u, int f, int s) {
	sub[u] = s;
	for(auto y : edge[u]) {
		int v = y.first, w = y.second;
		if(v == f) continue;
		d[v] = d[u] + w;
		getdis(v, u, s);
	}
}

void print() {
	cout << ans;
	exit(0);
}

int solve(int u) {
	if(vis[u]) print();
	vis[u] = 1;
	d[u] = 0;
	for(auto y : edge[u]) {
		int v = y.first, w = y.second;
		d[v] = w;
		getdis(v, u, v);
	}
	int mx = 0;
	vector<int> v;
	for(int i = 1; i <= m; i++) {
		int x = a[i].x, y = a[i].y;
		if(d[x] + d[y] > mx) v.clear(), v.push_back(i), mx = d[x] + d[y];
		else if(d[x] + d[y] == mx) v.push_back(i); 
	}
	if(mx < ans) ans = mx;
	int lst = 0;
	for(auto i : v) {
		int x = a[i].x, y = a[i].y;
		if(sub[x] != sub[y]) print(); 
		if(lst == 0) lst = sub[x];
		if(lst != sub[x]) print();
	}
	rt = 0, dp[0] = n + 1, tot = sze[lst];
	getrt(lst, u);
	solve(rt);
}

signed main() {
	cin >> n >> m;
	for(int i = 1; i < n; i++) {
		int u, v, w; cin >> u >> v >> w;
		edge[u].push_back({v, w});
		edge[v].push_back({u, w});
	}
	for(int i = 1; i <= m; i++) {
		cin >> a[i].x >> a[i].y;
	}
	rt = 0, dp[0] = n + 1, tot = n;
	getrt(1, 0);
	solve(rt);
	print();
	return 0;
}

练2.2.1 幻想乡战略游戏

题意

有一棵带权树。

找到一个点 \(u\),使得 \(\sum_{v = 1}^{n} dis(u, v) \times a_v\) 最小。

Sol

这道题要用点分树。

请学习完点分树以后再来查看。

做法差不多。

\(S(x)\) 表示 \(x\) 点的子树里的点权和。

设当前最优 \(x\) 的答案为 \(res\)\(x\) 为根。

\(y\)\(x\) 的一个儿子。

考略 \(\Delta res_{x \to y} = res + (S(x) - 2S(y)) \times w(x, y)\)

\(\Delta res_{x \to y} < 0 \Leftrightarrow S(x) < 2\times S(y)\)

这样的 \(y\) 至多只有一个。

证明:

假设存在两个 \(y1\)\(y2\),使得 \(S(x) < 2S(y1), S(x) < S(y2)\)

\(2S(x) < 2(S(y1) + S(y2))\),所以 \(S(y1) + S(y2) > S(x)\),不成立。

于是我们可以每次枚举点 \(x\) 的所有儿子,找到最优的那个,对他进行点分治即可。

计算 \(\sum_{v = 1}^{n} dis(u, v) \times a_v\) 直接点分树就行了。

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e5 + 10;
struct node {
    int v, w, rt;
};
int n;
vector<node> edge[N];
int dep[N], top[N], Dis[N], sze[N], son[N], Fa[N];

void dfs1(int u, int f) {
    dep[u] = dep[f] + 1;
    Fa[u] = f;
    sze[u] = 1;
    for (auto y : edge[u]) {
        int v = y.v, w = y.w;
        if (v == f)
            continue;
        Dis[v] = Dis[u] + w;
        dfs1(v, u);
        sze[u] += sze[v];
        if (sze[v] > sze[son[u]])
            son[u] = v;
    }
}

void dfs2(int u, int tp) {
    top[u] = tp;
    if (son[u])
        dfs2(son[u], tp);
    for (auto y : edge[u]) {
        int v = y.v;
        if (v != Fa[u] && v != son[u])
            dfs2(v, v);
    }
}

int lca(int u, int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]])
            swap(u, v);
        u = Fa[top[u]];
    }
    if (dep[u] < dep[v])
        return u;
    return v;
}

int dis(int u, int v) { return Dis[u] + Dis[v] - 2 * Dis[lca(u, v)]; }

int Mx[N], vis[N], rt, cnt;

void getrt(int u, int f) {
    sze[u] = 1, Mx[u] = 0;
    for (auto y : edge[u]) {
        int v = y.v, w = y.w;
        if (v == f || vis[v])
            continue;
        getrt(v, u);
        sze[u] += sze[v], Mx[u] = max(Mx[u], sze[v]);
    }
    Mx[u] = max(Mx[u], cnt - sze[u]);
    if (Mx[u] < Mx[rt])
        rt = u;
}

int fa[N];  // 点分树父亲

void init(int u) {
    vis[u] = 1;
    for (auto &y : edge[u]) {
        int v = y.v, w = y.w;
        if (vis[v])
            continue;
        rt = 0, Mx[0] = n + 1, cnt = sze[v];
        getrt(v, u);
        y.rt = rt;
        fa[rt] = u;
        init(rt);
    }
}

int f1[N], f2[N], sum[N];

void modify(int x, int val) {
    for (int u = x; u; u = fa[u]) {
        sum[u] += val;
    }
    for (int u = x; fa[u]; u = fa[u]) {
        int D = dis(fa[u], x);
        f1[fa[u]] += D * val;
        f2[u] += D * val;
    }
}

int query(int x) {
    int res = f1[x];
    for (int u = x; fa[u]; u = fa[u]) {
        int D = dis(x, fa[u]);
        res += (f1[fa[u]] - f2[u]);
        res += (sum[fa[u]] - sum[u]) * D;
    }
    return res;
}

int solve(int u) {
    int res = query(u);
    for (auto y : edge[u]) {
        int v = y.v;
        if (query(v) < res) {
            return solve(y.rt);
        }
    }
    return res;
}

signed main() {
    cin.tie(0), cout.tie(0);
    ios::sync_with_stdio(0);
    int T;
    cin >> n >> T;
    for (int i = 1; i < n; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        edge[u].push_back({ v, w, 0 });
        edge[v].push_back({ u, w, 0 });
    }
    dfs1(1, 0);
    dfs2(1, 0);
    rt = 0, Mx[0] = n + 1, cnt = n;
    getrt(1, 0);
    int Rt = rt;
    init(rt);

    while (T--) {
        int x, v;
        cin >> x >> v;
        modify(x, v);
        cout << solve(Rt) << "\n";
    }
    return 0;
}

动态树分治

动态树分治,又称点分树。

我会用尽量精简生动的语言将他描述出来。

这个数据结构真的挺难的,不过真的挺有用的。

每次点分治,我们会把这一层的重心与上一层的重心连边。

得到一颗树,称为点分树。

这样原本的父子关系就被完全打乱了。

那这个对我们解决问题有什么帮助呢?

有些问题我们并不关心树的形态,比如并查集,联通块问题。我们要求出两点间的路径,也不一定要求出 \(LCA\)。我们可以找一个分割点 \(p\),把路径分为 \(u \to p\)\(p \to v\)

点分树就是对原树做了这样的映射。

点分树有如下性质:

  1. 它的 深度\(\log n\),与点分治的层数一样。我们可以枚举点分树上的的所有父亲。甚至可以开一个 vector 存下每个点子树内的点。

  2. 对两点 \((u, v)\),它们在点分树上的 \(lca\).一定在它们的路径上, 也就是说, \(dis(u, v) = dis(u, lca) + dis(lca, v)\)。注意 \(dis\) 是在原树上的距离。

计算贡献

以下用 \(fa_x\) 表示 \(x\) 在点分树上的父亲节点,\(subtree(x)\) 表示 \(x\) 在点分树上的子树节点集合,\(A(x)\) 表示 \(x\) 的所有祖先节点集合,\(dis(x, y)\) 表示两点在 原树上的距离

枚举所有祖先节点当做中转点。

\(ans(i, j)\) 表示距离 \(i\) 小于等于 \(j\) 的点的点权和。

设以 \(a\) 为中转点,由于 \(a\) 在点分树的子树里的点已经被统计过了,那么要统计的是除去 \(a\)\(x\) 这侧的子树的所有点到 \(x\) 的距离小于等于 \(j\) 的答案。

\(f1(i, j)\) 表示在 \(i\) 点分树的子树里的点到 \(j\) 的距离小于等于 \(j\) 的点权和。

\[f1(i, j) = \sum_{x \in subtree(i) \land dis(x, i) \le j} a_x \]

为了除去某个点的子树。

\(f2(i, j)\) 表示在 \(i\) 点分树的子树里的点到 \(fa_i\) 的距离小于等于 \(j\) 的点权和。

\[f2(i, j) = \sum_{x \in subtree(i) \land dis(x, fa_i) \le j} a_x \]

于是 \(ans\) 可以计算。

\[ans(i, j) = f1(i, j) + \sum_{x \in A(i) \land fa(x) \land dis(i, x) \le j} f1(fa_x, j - dis(i, fa_x)) - f1(x, j - dis(i, fa_x)) \]

我们看到例题。

例2.1 震波

在一片土地上有 \(n\) 个城市,通过 \(n-1\) 条无向边互相连接,形成一棵树的结构,相邻两个城市的距离为 \(1\),其中第 \(i\) 个城市的价值为 \(value_i\)

不幸的是,这片土地常常发生地震,并且随着时代的发展,城市的价值也往往会发生变动。

接下来你需要在线处理 \(m\) 次操作:

0 x k 表示发生了一次地震,震中城市为 \(x\),影响范围为 \(k\),所有与 \(x\) 距离不超过 \(k\) 的城市都将受到影响,该次地震造成的经济损失为所有受影响城市的价值和。

1 x y 表示第 \(x\) 个城市的价值变成了 \(y\)

为了体现程序的在线性,操作中的 \(x\)\(y\)\(k\) 都需要异或你程序上一次的输出来解密,如果之前没有输出,则默认上一次的输出为 \(0\)

思路:

只需要处理修改操作。

直接暴力在点分树上跳父亲。

看代码吧,注意树状数组细节:

#include<bits/stdc++.h>
using namespace std;
//#define int long long
const int N = 2e5 + 10;
int n, m; 

vector<int > edge[N];
struct BIT{
    int sze;
    vector<int> c;
    void resize(int x) {
        sze = x - 1;
        c.resize(x);
    } 
    void add(int x, int y) {
        x++;
        for(;x <= sze; x += x & -x) c[x] += y;
    } 
    int query(int x) {
        x++;
        int res = 0;
        x = min(x, sze);
        for(;x ; x -= x & -x) res += c[x]; 
        return res;
    }
}w0[N], w1[N];
int fa[N][19], dep[N], lg2[N];

void dfs0(int u, int f) {
    fa[u][0] = f, dep[u] = dep[f] + 1;
    for(auto v : edge[u]) if(v != f) dfs0(v, u);	
}
int Lca(int x, int y) {
    if(dep[x] < dep[y]) swap(x, y);
    while(dep[x] > dep[y]) x = fa[x][lg2[dep[x] - dep[y]]];
    if(x == y) return x;
    for(int i = 18; i >= 0; i--) 
        if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}
int getdis(int x, int y) {
    return dep[x] + dep[y] - 2 * dep[Lca(x, y)];
}

//上面部分为预处理。

int vis[N], tot, rt, dp[N], sze[N];

void getrt(int u, int f) {
    dp[u] = 0, sze[u] = 1;
    for(auto v : edge[u]) {
        if(v == f || vis[v]) continue;
        getrt(v, u);
        sze[u] += sze[v];
        dp[u] = max(sze[v], dp[u]);
    }
    dp[u] = max(dp[u], tot - sze[u]);
    if(dp[u] < dp[rt]) rt = u;
}
int dfa[N], dsze[N];
void init(int u, int f) {
    dfa[u] = f;// 点分树上的父亲
    vis[u] = 1;
    w1[u].resize(tot + 2);
    w0[u].resize(tot + 2);// 注意空间
    for(auto v : edge[u]) {
        if(vis[v]) continue;
        rt = 0, dp[0] = n + 1, tot = sze[v];
        getrt(v, u);
        init(rt, u);
    }
}

int a[N];
void modify(int u, int w) {
    for(int i = u; i; i = dfa[i]) w0[i].add(getdis(u, i), w);
    for(int i = u; dfa[i]; i = dfa[i]) w1[i].add(getdis(u, dfa[i]), w);
}

int query(int u, int k) {
    int res = w0[u].query(k);
    for(int i = u; dfa[i]; i = dfa[i]) {
        int dis = getdis(u, dfa[i]);
        if(k >= dis) res += w0[dfa[i]].query(k - dis) - w1[i].query(k - dis);
    }
    return res; 
}

signed main() {
    cin.tie(0), cout.tie(0);
    ios::sync_with_stdio(0);
    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> a[i];
    for(int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        edge[u].push_back(v);
        edge[v].push_back(u); 
    } 
    for(int i = 1; i <= n; i++) lg2[i] = log2(i);
    dfs0(1, 0);
    for(int i = 1; i <= 18; i++) 
        for(int j = 1; j <= n; j++) 
            fa[j][i] = fa[fa[j][i - 1]][i - 1];
    rt = 0, dp[0] = n + 1, tot = n;
    getrt(1, 0);
    init(rt, 0);
    for(int i = 1; i <= n; i++) modify(i, a[i]);
    int ans = 0; 
    while(m--) {
        int op, x, y; cin >> op >> x >> y;
        x ^= ans, y ^= ans;
        if(op == 0) {
            ans = query(x, y);
            cout << ans << "\n";
        }
        else {
            modify(x, y - a[x]), a[x] = y;
        }
    }
    return 0;
}

练 2.1.1 P3241 [HNOI2015] 开店

P3241 [HNOI2015] 开店

\[f1(i, j) = \sum_{x \in subtree(i)\land w_x \le j} dis(i, x) \\ f2(i, j) = \sum_{x \in subtree(i)\land w_x \le j} dis(fa_i, x) \\ g1(i, j) = \sum_{x \in subtree(i)\land w_x \le j} 1 \\ g2(i, j) = \sum_{x \in subtree(i)\land w_x \le j} 1 \]

注意这里是计算距离和,之前只加上了 \(x\)\(fa_i\) 的距离,所以在查询 \(p\) 点时,漏掉了 \(fa_i\)\(p\) 这一段。

\(ans(i, j) = f1(i, j) + \sum_{x \in A(i)} \{f1(fa_i, j) - f2(i, j) + (g1(fa_i, j) - g2(i, j)) \times dis(i, x)\}\)

可以点分树时对每个点开 \(vector\) 记录子树内 \(w\) 的值查询时二分即可。

代码:
二分:


#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 2e5 + 10;
int n, M;

vector<pair<int, int> > edge[N];
int a[N];
struct SGT{
    vector<pair<int, int> > v;
    void change(int x, int y) {
        v.push_back({x, y});
    }
    
    void init() {
        sort(v.begin(), v.end());
        int sze = v.size();
        for(int i = 1; i < sze; i++) {
            v[i].second += v[i - 1].second;
        }
    }
    
    int query(int x) {
        auto p = upper_bound(v.begin(), v.end(), make_pair(x, (int)1e14));
        if(p == v.begin()) return 0;
        p--;
        return (*p).second;
    }
    
    int query(int l, int r) {
        return query(r) - query(l - 1);
    }
}w1[N], w2[N], w3[N];

int Fa[N][20], dep[N], dis[N], lg2[N];

void dfs0(int u, int f) {
    Fa[u][0] = f, dep[u] = dep[f] + 1;
    for(int i = 1; i <= 19; i++) Fa[u][i] = Fa[Fa[u][i - 1]][i - 1];
    for(auto y : edge[u]) {
        int v = y.first, w = y.second;
        if(v != f) dis[v] = dis[u] + w, dfs0(v, u);
    }	
}

int Lca(int x, int y) {
    if(dep[x] < dep[y]) swap(x, y);
    while(dep[x] > dep[y]) x = Fa[x][lg2[dep[x] - dep[y]]];
    if(x == y) return x;
    for(int i = 19; i >= 0; i--) if(Fa[x][i] != Fa[y][i]) x = Fa[x][i], y = Fa[y][i];
    return Fa[x][0];
}
int getdis(int x, int y) {
    return dis[x] + dis[y] - 2 * dis[Lca(x, y)];
}

int dp[N], sze[N], vis[N], s[N], cnt, rt;

void getrt(int u, int f) {
    sze[u] = 1, dp[u] = 0;
    for(auto y : edge[u]) {
        int v = y.first, w = y.second;
        if(v == f || vis[v]) continue;
        getrt(v, u);
        sze[u] += sze[v];
        dp[u] = max(dp[u], sze[v]);
    }
    dp[u] = max(dp[u], cnt - sze[u]);
    if(dp[u] < dp[rt]) rt = u;
}

int fa[N];

void init(int u) {
    vis[u] = 1;
    s[u] = cnt;
    for(auto y : edge[u]) {
        int v = y.first, w = y.second;
        if(vis[v]) continue;
        rt = 0, dp[0] = n + 1, cnt = sze[v];
        getrt(v, u);
        fa[rt] = u;
        init(rt);
    }
}

void modify(int u) {
    for(int i = u; i ; i = fa[i]) w1[i].change(a[u], getdis(u, i));
    for(int i = u; fa[i]; i = fa[i]) w2[i].change(a[u], getdis(u, fa[i]));	
    for(int i = u; i; i = fa[i]) w3[i].change(a[u], 1);
}

int query(int u, int l, int r) {
    int res = w1[u].query(l, r);
    for(int i = u; fa[i] ; i = fa[i]) {
        res = res + w1[fa[i]].query(l, r) - w2[i].query(l, r) + getdis(u, fa[i]) * (w3[fa[i]].query(l, r) - w3[i].query(l, r)); 
    }
    return res;
}


signed main() {
    cin.tie(0), cout.tie(0);
    ios::sync_with_stdio(0);
    int T;
    cin >> n >> T >> M;
    for(int i = 1; i <= n; i++) {
        cin >> a[i]; a[i]++;
    }
    for(int i = 1; i < n; i++) {
        int u, v, w; cin >> u >> v >> w;
        edge[u].push_back({v, w});
        edge[v].push_back({u, w});
    }
    dfs0(1, 0);
    rt = 0, dp[0] = n + 1, cnt = n;
    getrt(1, 0);
    init(rt);
    for(int i = 1; i <= n; i++) lg2[i] = log2(i);
    for(int i = 1; i <= n; i++) modify(i);
    for(int i = 1; i <= n; i++) w1[i].init(), w2[i].init(), w3[i].init();
    int ans = 0;
    while(T--) {
        int u, a, b;
        cin >> u >> a >> b;
        int L = min((a + ans) % M, (b + ans) % M) + 1, R = max((a + ans) % M, (b + ans) % M) + 1;
        cout << (ans = query(u, L, R)) << "\n";
    }
    return 0;
} 

动态开店线段树版,只有结构体有变化

struct SGT {
    vector<pair<int, int> > v;
    void change(int x, int y) { v.push_back({ x, y }); }

    void init() {
        sort(v.begin(), v.end());
        int sze = v.size();
        for (int i = 1; i < sze; i++) {
            v[i].second += v[i - 1].second;
        }
    }

    int query(int x) {
        auto p = upper_bound(v.begin(), v.end(), make_pair(x, (int)1e14));
        if (p == v.begin())
            return 0;
        p--;
        return (*p).second;
    }

    int query(int l, int r) { return query(r) - query(l - 1); }
} w1[N], w2[N], w3[N];

练 2.1.2 P5311 [Ynoi2011] 成都七中

P5311 [Ynoi2011] 成都七中

题意:

给你一棵 \(n\) 个节点的树,每个节点有一种颜色,有 \(m\) 次查询操作。

查询操作给定参数 \(l\ r\ x\),需输出:

将树中编号在 \([l,r]\) 内的所有节点保留,\(x\) 所在连通块中颜色种类数。

每次查询操作独立。

Sol

点分树有性质:树上任意一个联通块,存在一个在点分树上深度最小的点,并且整个联通块都在这个点的子树当中。

证明:

问题变成:在一棵树中,从根节点出发,只经过 \([l,r]\) 范围内的点,可以到达的颜色数。

使用反证法。

假设点分树上最浅的节点叫做 \(p\),联通块中有一个点 \(q\) 不在 \(p\) 的子树内。由于点分树上 \(p\) 的子树也是原树中的一个联通块,并且所有 \(p\) 子树之外的点到达 \(p\) 都必须经过一个比p更浅的节点,所以q到p的路径上最浅的点一定比 \(p\) 浅,而同时这个点一定也在联通块内,这违反了“ \(p\) 是最浅的节点”。

所以我们可以把每一个询问 \(l,r,x\),归在和 \(x\) 在同一个联通块中的点分树上最浅的节点。然后就只需要对点分树上每个点遍历一遍子树,分开来处理。

注意要判断这个点是否能走到点分树上的祖先节点。

我们记录一下每一个节点到(点分树上的)每个祖先的路径上的编号最大的点和编号最小的点,分别设成 \(L\)\(R\)

我们发现,只有对于 \(x\) 节点拥有的一个询问 \((l,r)\) ,只有 \(L\ge l\)\(R \le r\)的节点才能对答案有贡献。

将询问离线一波,第一维排序第二维树状数组维护即可解决。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
int n;
vector<int> edge[N];
int c[N];
void add(int x, int y) {
    for (; x <= n; x += x & -x) c[x] += y;
}

int query(int x) {
    if (!x)
        return 0;
    int res = 0;
    for (; x; x -= x & -x) res += c[x];
    return res;
}

void del(int x) {
    for (; x <= n; x += x & -x) c[x] = 0;
}

struct node {
    int l, r, op;
};

int a[N];
int ans[N];
vector<node> q[N];

int vis[N], dp[N], sze[N], rt, cnt;
void getrt(int u, int f) {
    dp[u] = 0, sze[u] = 1;
    for (auto v : edge[u]) {
        if (v == f || vis[v])
            continue;
        getrt(v, u);
        dp[u] = max(dp[u], sze[v]);
        sze[u] += sze[v];
    }
    dp[u] = max(dp[u], cnt - sze[u]);
    if (dp[u] < dp[rt])
        rt = u;
}
int mn[N], mx[N], col[N];
vector<node> t;
void getres(int u, int f) {
    mn[u] = min(mn[f], u), mx[u] = max(mx[f], u);
    t.push_back({ mn[u], mx[u], -a[u] });
    for (auto x : q[u]) {
        if ((!ans[x.op]) && x.l <= mn[u] && mx[u] <= x.r)
            t.push_back({ x.l, x.r, x.op });
    }
    for (auto v : edge[u]) {
        if (vis[v] || v == f)
            continue;
        getres(v, u);
    }
}

bool cmp(node x, node y) {
    if (x.l != y.l)
        return x.l > y.l;
    return x.op < y.op;
}

void solve(int u) {
    vis[u] = 1;
    t.clear();
    mx[0] = -1e9, mn[0] = 1e9;
    getres(u, 0);
    sort(t.begin(), t.end(), cmp);
    for (auto x : t) {
        if (x.op < 0) {
            x.op *= -1;
            if (!col[x.op])
                add(x.r, 1), col[x.op] = x.r;
            else if (x.r < col[x.op]) {
                add(col[x.op], -1);
                add(x.r, 1);
                col[x.op] = x.r;
            }
        } else {
            ans[x.op] = query(x.r) - query(x.l - 1);
        }
    }
    for (auto x : t) {
        if (x.op < 0)
            del(x.r), col[-x.op] = 0;
    }

    for (auto v : edge[u]) {
        if (vis[v])
            continue;
        rt = 0, dp[0] = n + 1, cnt = sze[v];
        getrt(v, u);
        solve(rt);
    }
}

int main() {
    int T;
    cin >> n >> T;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        edge[u].push_back(v);
        edge[v].push_back(u);
    }
    for (int i = 1; i <= T; i++) {
        int l, r, x;
        cin >> l >> r >> x;
        q[x].push_back({ l, r, i });
    }
    rt = 0, dp[0] = n + 1, cnt = n;
    getrt(1, 0);
    solve(rt);
    for (int i = 1; i <= T; i++) cout << ans[i] << "\n";
    return 0;
posted @ 2025-09-03 23:14  merlinkkk  阅读(22)  评论(0)    收藏  举报