【算法】树分治

1. 算法简介

树分治(Tree division),是处理树上路径类问题的算法。树分治又可以分为点分治边分治

考虑这样一个问题:给定一棵有 \(n\) 个点的树,询问树上距离为 \(k\) 的点对是否存在。

暴力的做法就是枚举两个点然后计算距离,统计答案。这样显然 \(O(n^2)\) 的。

我们发现瓶颈在于枚举的过程:我们希望快速地知道树上的路径信息,而不在乎路径上的端点。

这时候就需要使用树分治算法来优化时间。

2. 点分治

点分治是树分治的一种。

大家可能看出来了,上述的例题就是 P3806 【模板】点分治 1

对于一棵树而言,树上的路径无外乎两种:一种是经过根节点的,另一种是不经过根节点的。(前提是有根树,无根树可转为有根树)

对于经过根节点的路径,想要知道其路径信息是很容易的。但不经过根节点的路径就很难维护了(即所在子树相同)。以当前为根的树很难维护其子树路径的信息。这时候我们便可以删去当前根节点,分裂成许多以儿子节点为子树根的新树。

分裂之前:

image

分裂之后:

image

由于每一个节点都当过根节点,这样,树上的所有路径等能被统计到。

但我们会发现,当树为链时,分治的时间复杂度依然为 \(O(n)\) ,没有达到优化时间的目的。

这是,按照原来老老实实从根节点开始分治的方法已经不适用,这时候我们需要找到一个合适的点,使得分治之后,时间复杂度趋近于 \(O(\log n)\)

这个点就是树的重心

2.1 求重心

树的中心定义为:其所有的子树中最大的子树节点数最少。当删去此点时,生成的多棵新树会趋于平衡,这也会让点分治的时间复杂度趋于 \(O(\log n)\)

找中心,便一边 \(dfs\) 整棵树即可。

\(maxs_x\) 表示 \(x\) 节点的最大的子树大小,\(siz_x\) 表示 \(x\) 节点的子树大小,\(rt\) 为选出的根。

考虑一个节点在原树上的位置。值得注意的是,当此节点不为根节点时,其子树包括其父辈之外的所有节点,像这样:

image

蓝色圈出部分为 'now' 节点的所有子树。

Code:

void getrt(int x, int fa) {
  siz[x] = 1, maxs[x] = 0;
  for (int i = h[x]; i; i = e[i].nx) {
    int y = e[i].v;
    if(y == fa || vis[y]) continue;
    getrt(y, x);
    siz[x] += siz[y];
    if(maxs[x] < siz[y]) maxs[x] = siz[y]; 
  }
  maxs[x] = max(maxs[x], sum - siz[x]);
  if(maxs[rt] > maxs[x]) rt = x;
}

注意!!!!,当分裂成多个子树之后,分治到子树时,则需重新找重心。来保证程序的时间复杂度。

由于重新找重心是在分治的过程中完成的,故总时间复杂度不会超过 \(O(n\log n)\)

2.2 分治

找到了重心之后,便可以以重心为根进行分治。

\(vis_x\) bool 数组表示 \(x\) 节点是否被“删除”(删除的节点不能被再次遍历,也不能再次进行答案统计)。

由于每次进入下一层分治时要重新找重心,故分治时要及时把 \(maxs_rt\) 设置为 \(siz_y\)(最大不会超过 \(siz_y\)),根节点编号也要设置为 \(0\)

找完之后,以新重心为根,继续分治,统计答案;

Code:

void divide(int x) {
  vis[x] = f[0] = 1;
  solve(x);
  for (int i = h[x]; i; i = e[i].nx) {
    int y = e[i].v;
    if(vis[y]) continue;//遍历过的点包含其父亲节点
    maxs[rt = 0] = sum = siz[y];
    getrt(y, 0);
    divide(rt);
  }
}

当你在统计答案时想要使用 \(siz\) 时,直接在 'getrt(y, 0);' 补一句 'getrt(rt, 0)' 即可

2.3 答案统计

这里因题而异,拿模板题举例。

\(f_x\) 表示 \(x\) 在当前状态内是否出现,\(tmp_i\) 存储有多少种不同的路径长度。

先可以把新树的点到根的距离求出来:

Code:

void getdis(int x, int fa) {
  tmp[++cnt] = dis[x];
  for (int i = h[x]; i; i = e[i].nx) {
    int y = e[i].v;
    if(y == fa || vis[y]) continue;
    dis[y] = dis[x] + e[i].w;
    getdis(y, x);
  }
}

然后统计答案

Code:

void solve(int x) {
  int H = 1, t = 0;
  for (int i = h[x]; i; i = e[i].nx) {
    int y = e[i].v;
    if(vis[y]) continue;
    dis[y] = e[i].w;
    cnt = 0;
    getdis(y, x);
    For(j,1,cnt) {
      For(k,1,m) {
        if(Q[k] >= tmp[j]) ans[k] |= f[Q[k] - tmp[j]]; 
      }
    }
    For(j,1,cnt) {
      q[++t] = tmp[j];
      f[tmp[j]] = 1;
    }
  }
  while(H <= t) {
    f[q[H]] = 0;
    H++;
  }
}

3. 点分治例题

3.1 P4178 Tree

Proble

给定一棵有 \(n\) 个节点的树,每条边有边权,求出树上两点距离小于等于 \(k\) 的点对数量。

Solve

点分治模板。

考虑统计答案时,令路径长度数组为 \(tmp\),当新出现边权 \(tmp_x\) 时,看已有的边权里是否出现 \(tmp_y\) 使得 \(tmp_x + tmp_y \le k\)。推到得 \(tmp_y <= k - tmp_x\),所以将区间 \([0,k-tmp_x]\) 计入答案,再单点 \(tmp_x\) 加一即可。

树状数组可维护。

Code

#include <bits/stdc++.h>
#define int long long
#define rint register int
#define For(i,l,r) for(rint i=l;i<=r;++i)
#define FOR(i,r,l) for(rint i=r;i>=l;--i)
#define MOD 1000003
#define mod 1000000007
#define inf 0x3f3f3f3f3f3f3f3f

using namespace std;

namespace Read {
  template <typename T>
  inline void read(T &x) {
    x=0;T f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    x*=f;
  }
  template <typename T, typename... Args>
  inline void read(T &t, Args&... args) {
    read(t), read(args...);
  }
}

using namespace Read;

void print(int x){
  if(x<0){putchar('-');x=-x;}
  if(x>9){print(x/10);putchar(x%10+'0');}
  else putchar(x+'0');
  return;
}

const int N = 4e4 + 10;

struct Node {
  int v, w, nx;
} e[N << 1];

int n, h[N], k, tot, rt, sum, siz[N], maxs[N], t[N], dis[N], tmp[N], cnt, ans;

bool vis[N];

void add(int u, int v, int w) {
  e[++tot] = (Node) {v, w, h[u]};
  h[u] = tot;
}

void getrt(int x, int fa) {
  siz[x] = 1, maxs[x] = 0;
  for (int i = h[x]; i; i = e[i].nx) {
    int y = e[i].v;
    if(y == fa || vis[y]) continue;
    getrt(y, x);
    siz[x] += siz[y];
    if(maxs[x] < siz[y]) maxs[x] = siz[y]; 
  }
  maxs[x] = max(maxs[x], sum - siz[x]);
  if(maxs[rt] > maxs[x]) rt = x;
}

void getdis(int x, int fa) {
  tmp[++cnt] = dis[x];
  for (int i = h[x]; i; i = e[i].nx) {
    int y = e[i].v;
    if(y == fa || vis[y]) continue;
    dis[y] = dis[x] + e[i].w;
    getdis(y, x);
  }
}

int lb(int x) {
  return x & -x;
}

int qry(int x) {
  int Ans = 0;
  for (int i = x; i; i -= lb(i)) {
    Ans += t[i];
  }
  return Ans;
}

void upd(int x, int z) {
  for (int i = x; i <= k + 1; i += lb(i)) {
    t[i] += z;
  }
}

void solve(int x) {
  for (int i = h[x]; i; i = e[i].nx) {
    int y = e[i].v;
    if(vis[y]) continue;
    dis[y] = e[i].w;
    cnt = 0;
    getdis(y, x);
    For(j,1,cnt) {
      if(tmp[j] <= k) ans += qry(k - tmp[j] + 1);
    }
    For(j,1,cnt) {
      if(tmp[j] <= k) upd(tmp[j] + 1, 1);
    }
  }
  memset(t, 0, sizeof t);
  upd(1, 1);
}

void divide(int x) {
  vis[x] = 1;
  solve(x);
  for (int i = h[x]; i; i = e[i].nx) {
    int y = e[i].v;
    if(vis[y]) continue;
    maxs[rt = 0] = n; sum = siz[y];
    getrt(y, 0);
    divide(rt);
  }
}

signed main() {
  read(n);
  For(i,1,n-1) {
    int u, v, w;
    read(u, v, w);
    add(u, v, w);
    add(v, u, w);
  }
  read(k);
  maxs[0] = sum = n;
  getrt(1, 0);
  upd(1, 1);
  divide(rt);
  cout << ans << '\n';
  return 0;
}

3.2 P4149 [IOI2011] Race

Problem

给一棵树,每条边有权。求一条简单路径,权值和等于 \(k\),且边的数量最小。

Solve

点分治模板

在模板题的基础上多加一个记录深度,每次记录深度时取最大值,然后找到权值之和为 \(k\) 的路径就用深度更新答案。

\(k=tmp_j\) 时,无需再找路径进行拼接,直接更新。

存贮时要判断 \(tmp_j\le k\),直接存 \(tmp_j\) 可能会爆掉。

Code

#include <bits/stdc++.h>
#define ll long long
#define rint register int
#define For(i,l,r) for(rint i=l;i<=r;++i)
#define FOR(i,r,l) for(rint i=r;i>=l;--i)
#define MOD 1000003
#define mod 1000000007
#define inf 0x3f3f3f3f

using namespace std;

namespace Read {
  template <typename T>
  inline void read(T &x) {
    x=0;T f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    x*=f;
  }
  template <typename T, typename... Args>
  inline void read(T &t, Args&... args) {
    read(t), read(args...);
  }
}

using namespace Read;

void print(int x){
  if(x<0){putchar('-');x=-x;}
  if(x>9){print(x/10);putchar(x%10+'0');}
  else putchar(x+'0');
  return;
}

const int N = 2e5 + 10, M = 2e6 + 10; 

struct Node {
  int v, w, nx;
} e[N << 1];

int n, k, h[N], tot, maxs[N], f[M], tmp[M], rt, sum, cnt, siz[N], ans = inf, q[M], dep[M], dis[N];

bool vis[N];

void add(int u, int v, int w) {
  e[++tot] = (Node){v, w, h[u]};
  h[u] = tot;
}

void getrt(int x, int fa) {
  maxs[x] = 0, siz[x] = 1;
  for (int i = h[x]; i; i = e[i].nx) {
    int y = e[i].v;
    if(y == fa || vis[y]) continue;
    getrt(y, x);
    siz[x] += siz[y];
    if(maxs[x] < siz[y]) maxs[x] = siz[y];
  }
  maxs[x] = max(maxs[x], sum - siz[x]);
  if(maxs[rt] > maxs[x]) rt = x; 
}

void getdis(int x, int fa, int dp) {
  tmp[++cnt] = dis[x];
  dep[cnt] = dp;
  for (int i = h[x]; i; i = e[i].nx) {
    int y = e[i].v;
    if(y == fa || vis[y]) continue;
    dis[y] = dis[x] + e[i].w;
    getdis(y, x, dp + 1);
  }
}

void solve(int x) {
  int H = 1, t = 0;
  for (int i = h[x]; i; i = e[i].nx) {
    int y = e[i].v;
    if(vis[y]) continue;
    dis[y] = e[i].w;
    cnt = 0;
    getdis(y, x, 1);
    For(j,1,cnt) {
      if(k >= tmp[j] && f[k - tmp[j]] != inf) {
        ans = min(ans, dep[j] + f[k - tmp[j]]);
      }
      if(k == tmp[j]) {
        ans = min(ans, dep[j]);
      }
    }
    For(j,1,cnt) {
      if(k >= tmp[j]) {
		f[tmp[j]] = min(f[tmp[j]], dep[j]);
      	q[++t] = tmp[j];
	  }
    }
  }
  while(H <= t) {
    f[q[H]] = inf; H++;
  }
}

void divide(int x) {
  vis[x] = 1, f[0] = inf;
  solve(x);
  for (int i = h[x]; i; i = e[i].nx) {
    int y = e[i].v;
    if(vis[y]) continue;
    maxs[rt = 0] = n, sum = siz[y];
    getrt(y, 0);
    divide(rt);
  }
}

signed main() {
  read(n, k);
  For(i,1,n-1) {
    int u, v, w;
    read(u, v, w);
    u++, v++;
    add(u, v, w);
    add(v, u, w);
  }
  memset(f, 0x3f, sizeof f);
  maxs[0] = sum = n;
  getrt(1, 0);
  divide(rt);
  if(ans != inf) cout << ans << '\n';
  else puts("-1");
  return 0;
}
posted @ 2024-02-04 15:51  Daniel_yzy  阅读(28)  评论(0编辑  收藏  举报
Title