点分治

点分治

模板题引入:点分治1

题目描述

给定一棵有 \(n\) 个点的树,询问树上距离为 \(k\) 的点对是否存在。

输入格式

第一行两个数 \(n,m\)

\(2\) 到第 \(n\) 行,每行三个整数 \(u, v, w\),代表树上存在一条连接 \(u\)\(v\) 边权为 \(w\) 的路径。

接下来 \(m\) 行,每行一个整数 \(k\),代表一次询问。

输出格式

对于每次询问输出一行一个字符串代表答案,存在输出 AYE,否则输出 NAY

样例 #1

样例输入 #1

2 1
1 2 2
2

样例输出 #1

AYE

提示

数据规模与约定

  • 对于 \(30\%\) 的数据,保证 \(n\leq 100\)
  • 对于 \(60\%\) 的数据,保证 \(n\leq 1000\)\(m\leq 50\)
  • 对于 \(100\%\) 的数据,保证 \(1 \leq n\leq 10^4\)\(1 \leq m\leq 100\)\(1 \leq k \leq 10^7\)\(1 \leq u, v \leq n\)\(1 \leq w \leq 10^4\)

点分治介绍

树分治有点分治边分治两种,适合处理大规模树上路径信息问题。

树上的路径可以分为两种:

  1. 经过根节点的路径
  2. 不经过根节点的路径
  • 对于经过根节点的路径

可以预处理出每个点到根的路径,然后 dis[u][v]=dis[u][root]+dis[v][root]

注意排除不合法路径(u,v 在同一棵子树内),先把前面子树中各点到根的距离存入一个队列 q[i],并且开一个布尔数组存入队列中的距离 judge[q[i]],再枚举当前子树中各点到根的距离 dis[j]。若询问距离 kdis[j] 的差存在,即 judge[k-dis[j]] 为真,说明此解合法。

  • 对于不经过根节点的路径

可以对子树不断分治,转化为经过根节点的路径。

如果是一棵平衡树,分治次数为 \(O(log\ n)\),每次分治后跑 \(n\) 个点,询问 \(m\) 次,每次判定答案是 \(O(nm)\),时间复杂度为 \(O(nmlog\ n)\)

如果是一条链,且从链的一端开始分治,分治次数将退化为 \(O(n)\),时间复杂度为 \(O(n^2m)\)

分治前,对每棵子树先找出重心做根即可。

点分治的四步操作:

  1. 找出树的重心做根 get_root()
  2. 求出子树中的各点到根的距离 get_dis()
  3. 对当前树统计答案 calc()
  4. 分治各个子树,重复以上操作 divide()
#include<iostream>
#include<algorithm>
using namespace std;

const int N=10005;
const int INF=10000005;
struct node{int v,w,ne;}e[N<<1];
int h[N],idx; //加边
int del[N],siz[N],mxs,sum,root;//求根
int dis[N],d[N],cnt; //求距离
int ans[N],q[INF],judge[INF];//求路径
int n,m,ask[N];

void add(int u,int v,int w){
  e[++idx].v=v; e[idx].w=w;
  e[idx].ne=h[u]; h[u]=idx;
}
void getroot(int u,int fa){
  siz[u]=1;
  int s=0;
  for(int i=h[u];i;i=e[i].ne){
    int v=e[i].v;
    if(v==fa||del[v])continue;
    getroot(v,u);
    siz[u]+=siz[v];
    s=max(s,siz[v]);
  }
  s=max(s,sum-siz[u]);
  if(s<mxs) mxs=s, root=u;
}
void getdis(int u,int fa){
  dis[++cnt]=d[u];
  for(int i=h[u];i;i=e[i].ne){
    int v=e[i].v;
    if(v==fa||del[v])continue;
    d[v]=d[u]+e[i].w;
    getdis(v,u);
  }
}
void calc(int u){
  judge[0]=1;
  int p=0;
  // 计算经过根u的路径
  for(int i=h[u];i;i=e[i].ne){
    int v=e[i].v;
    if(del[v])continue;
    // 求出子树v的各点到u的距离
    cnt=0;
    d[v]=e[i].w;
    getdis(v,u);
    // 枚举距离和询问,判定答案
    for(int j=1;j<=cnt;++j)
      for(int k=1;k<=m;++k)
        if(ask[k]>=dis[j])
          ans[k]|=judge[ask[k]-dis[j]];
    // 记录合法距离
    for(int j=1;j<=cnt;++j)
      if(dis[j]<INF)
        q[++p]=dis[j], judge[q[p]]=1;
  }
  // 清空距离数组
  for(int i=1;i<=p;++i) judge[q[i]]=0;
}
void divide(int u){
  // 计算经过根u的路径
  calc(u);
  // 对u的子树进行分治
  del[u]=1;
  for(int i=h[u];i;i=e[i].ne){
    int v=e[i].v;
    if(del[v])continue;
    mxs=sum=siz[v];
    getroot(v,0); //求根
    divide(root); //分治
  }
}
int main(){
  scanf("%d%d",&n,&m);
  for(int i=1;i<n;++i){
    int u,v,w;
    scanf("%d%d%d",&u,&v,&w);
    add(u,v,w);add(v,u,w);
  }
  for(int i=1;i<=m;++i)
    scanf("%d",&ask[i]);
  mxs=sum=n;
  getroot(1,0);
  getroot(root,0); //重构siz[]
  divide(root);
  for(int i=1;i<=m;++i)
    ans[i]?puts("AYE"):puts("NAY");
  return 0;
}
  • 本题还可以利用容斥原理统计答案
#include<iostream>
#include<algorithm>
using namespace std;

const int N=10005;
const int INF=10000005;
struct node{int v,w,ne;}e[N<<1];
int h[N],idx; //加边
int del[N],siz[N],mxs,sum,root;//求根
int dis[N],d[N],cnt; //求距离
int ans[N];//求路径
int n,m,ask[N];

void add(int u,int v,int w){
  e[++idx].v=v; e[idx].w=w;
  e[idx].ne=h[u]; h[u]=idx;
}
void getroot(int u,int fa){
  siz[u]=1;
  int s=0;
  for(int i=h[u];i;i=e[i].ne){
    int v=e[i].v;
    if(v==fa||del[v])continue;
    getroot(v,u);
    siz[u]+=siz[v];
    s=max(s,siz[v]);
  }
  s=max(s,sum-siz[u]);
  if(s<mxs) mxs=s, root=u;
}
void getdis(int u,int fa){
  dis[++cnt]=d[u];
  for(int i=h[u];i;i=e[i].ne){
    int v=e[i].v;
    if(v==fa||del[v])continue;
    d[v]=d[u]+e[i].w;
    getdis(v,u);
  }
}
void calc(int u, int w, int sign){
  cnt = 0,d[u] = w;
  getdis(u, 0);//求距离
  sort(dis+1, dis+cnt+1);
  for(int i=1;i<=m;i++){
    int l=1, r=cnt;
    while(l < r){
      if(dis[l]+dis[r]<=ask[i]){
        if(dis[l]+dis[r]==ask[i])ans[i]+=sign;
        ++l;
      }
      else --r;
    }
  }
}
void divide(int u){
  calc(u, 0, 1); //求答案
  del[u] = 1;
  for(int i=h[u];i;i=e[i].ne){
    int v = e[i].v;
    if(del[v]) continue;
    calc(v, e[i].w, -1); //容斥
    mxs =sum = siz[v];
    getroot(v, u); //求根
    divide(root);  //分治
  }
}
int main(){
  scanf("%d%d",&n,&m);
  for(int i=1;i<n;++i){
    int u,v,w;
    scanf("%d%d%d",&u,&v,&w);
    add(u,v,w);add(v,u,w);
  }
  for(int i=1;i<=m;++i)
    scanf("%d",&ask[i]);
  mxs=sum=n;
  getroot(1,0);
  getroot(root,0); //重构siz[]
  divide(root);
  for(int i=1;i<=m;++i)
    ans[i]?puts("AYE"):puts("NAY");
  return 0;
}

例题:树上括号序列

题目描述

给定一棵树,树上每个节点的字符可以是 ( 或者 )

求有序点对 的数量,使得 到 的路径上的点的字符(包括 和 )所构成的括号序列合法。

输入格式

第一行一个整数 ,表示节点数量。

第二行一个字符串,由 () 组成,表示每个节点的字符。

接下来 行,每行 个整数 和 ,表示树上的一条边。

输出格式

一行一个整数,表示点对的数量。

样例

输入1

4
(())
1 2
2 3
3 4

输出1

2

样例解释

合法的点对有 \((1,4)\)\((2,3)\)

输入2

5
())((
1 2
2 3
2 4
3 5

输出2

3

样例解释

合法的点对有 \((1,2)\)\((4,2)\)\((5,3)\)

输入3

7
)()()((
1 2
1 3
1 6
2 4
4 5
5 7

输出3

6

数据范围与提示

子任务编号 \(n\le\) 特殊性质 分值
\(1\) \(10^3\) \(10\)
\(2\) \(3\times 10^5\) 树是一条链 \(30\)
\(3\) \(3\times 10^5\) \(60\)
#include <iostream>
using namespace std;
using LL = long long;
const int N = 3e5 + 7, INF = 0x3F3F3F3F;
int n, cnt, hd[N], rt, tot, sz[N], mx;
char s[N];
struct Edge {
    int v, nx;
} eg[N << 1];
void addE(int u, int v, int c) { eg[c] = { v, hd[u] }, hd[u] = c; }
inline void getmx(int &x, int y) {
    if (x < y)
        x = y;
}
inline void getmn(int &x, int y) {
    if (x > y)
        x = y;
}
bool vis[N];
void getrt(int u, int fa) {
    sz[u] = 1;
    int cur = 0;
    for (int i = hd[u]; i; i = eg[i].nx) {
        int v = eg[i].v;
        if (vis[v] || v == fa)
            continue;
        getrt(v, u);
        getmx(cur, sz[v]);
        sz[u] += sz[v];
    }
    getmx(cur, tot - sz[u]);
    if (cur < mx)
        mx = cur, rt = u;
}
// tot1[i] 左括号比右括号多i个的前缀数量, tot2[i] 右括号比左括号多i的后缀数量
int up, cnt1[N], cnt2[N], tot1[N], tot2[N];
void dfs(int u, int fa, int s1, int s2, int mn1, int mn2) {
    sz[u] = 1;
    int t = (s[u] == '(' ? 1 : -1);
    s1 += t, s2 -= t, mn1 = min(mn1, 0) + t, mn2 = min(mn2, 0) - t;
    if (mn1 >= 0)
        ++cnt1[s1], getmx(up, s1);
    if (mn2 >= 0)
        ++cnt2[s2], getmx(up, s2);
    for (int i = hd[u]; i; i = eg[i].nx) {
        int v = eg[i].v;
        if (vis[v] || v == fa)
            continue;
        dfs(v, u, s1, s2, mn1, mn2);
        sz[u] += sz[v];
    }
}
LL ans;
void calc() {
    for (int i = 0; i <= up; ++i) {
        ans -= (LL)cnt1[i] * cnt2[i];
        tot1[i] += cnt1[i], tot2[i] += cnt2[i];
        cnt1[i] = cnt2[i] = 0;
    }
}
void solve(int u) {
    vis[u] = 1;
    int hi = 0, t = (s[u] == '(' ? 1 : -1);
    if (t == 1)
        tot1[1] = hi = 1;
    tot2[0] = 1;
    for (int i = hd[u]; i; i = eg[i].nx) {
        int v = eg[i].v;
        if (vis[v])
            continue;
        up = 0;
        dfs(v, 0, t, 0, t, 0);
        getmx(hi, up);
        calc();
    }
    for (int i = 0; i <= hi; ++i) {
        ans += (LL)tot1[i] * tot2[i];
        tot1[i] = tot2[i] = 0;
    }
    for (int i = hd[u]; i; i = eg[i].nx) {
        int v = eg[i].v;
        if (vis[v])
            continue;
        mx = INF, tot = sz[v];
        getrt(v, 0);
        solve(rt);
    }
}
int main() {
    scanf("%d%s", &n, s + 1);
    for (int i = 1, u, v; i < n; ++i) {
        scanf("%d%d", &u, &v);
        addE(u, v, i << 1), addE(v, u, i << 1 | 1);
    }
    mx = INF, tot = n;
    getrt(1, 0);
    solve(rt);
    printf("%lld\n", ans);
}
posted @ 2024-04-13 14:34  飞花阁  阅读(22)  评论(0)    收藏  举报
//雪花飘落效果