点分治
点分治
模板题引入:点分治1
题目描述
给定一棵有 \(n\) 个点的树,询问树上距离为 \(k\) 的点对是否存在。
输入格式
第一行两个数 \(n,m\)。
第 \(2\) 到第 \(n\) 行,每行三个整数 \(u, v, w\),代表树上存在一条连接 \(u\) 和 \(v\) 边权为 \(w\) 的路径。
接下来 \(m\) 行,每行一个整数 \(k\),代表一次询问。
输出格式
对于每次询问输出一行一个字符串代表答案,存在输出 AYE,否则输出 NAY。
样例 #1
样例输入 #1
2 1
1 2 2
2
样例输出 #1
AYE
提示
数据规模与约定
- 对于 \(30\%\) 的数据,保证 \(n\leq 100\)。
- 对于 \(60\%\) 的数据,保证 \(n\leq 1000\),\(m\leq 50\) 。
- 对于 \(100\%\) 的数据,保证 \(1 \leq n\leq 10^4\),\(1 \leq m\leq 100\),\(1 \leq k \leq 10^7\),\(1 \leq u, v \leq n\),\(1 \leq w \leq 10^4\)。
点分治介绍
树分治有点分治和边分治两种,适合处理大规模树上路径信息问题。
树上的路径可以分为两种:
- 经过根节点的路径
- 不经过根节点的路径
- 对于经过根节点的路径
可以预处理出每个点到根的路径,然后 dis[u][v]=dis[u][root]+dis[v][root]。
注意排除不合法路径(u,v 在同一棵子树内),先把前面子树中各点到根的距离存入一个队列 q[i],并且开一个布尔数组存入队列中的距离 judge[q[i]],再枚举当前子树中各点到根的距离 dis[j]。若询问距离 k 与 dis[j] 的差存在,即 judge[k-dis[j]] 为真,说明此解合法。
- 对于不经过根节点的路径
可以对子树不断分治,转化为经过根节点的路径。
如果是一棵平衡树,分治次数为 \(O(log\ n)\),每次分治后跑 \(n\) 个点,询问 \(m\) 次,每次判定答案是 \(O(nm)\),时间复杂度为 \(O(nmlog\ n)\)。
如果是一条链,且从链的一端开始分治,分治次数将退化为 \(O(n)\),时间复杂度为 \(O(n^2m)\)。
分治前,对每棵子树先找出重心做根即可。
点分治的四步操作:
- 找出树的重心做根
get_root() - 求出子树中的各点到根的距离
get_dis() - 对当前树统计答案
calc() - 分治各个子树,重复以上操作
divide()
#include<iostream>
#include<algorithm>
using namespace std;
const int N=10005;
const int INF=10000005;
struct node{int v,w,ne;}e[N<<1];
int h[N],idx; //加边
int del[N],siz[N],mxs,sum,root;//求根
int dis[N],d[N],cnt; //求距离
int ans[N],q[INF],judge[INF];//求路径
int n,m,ask[N];
void add(int u,int v,int w){
e[++idx].v=v; e[idx].w=w;
e[idx].ne=h[u]; h[u]=idx;
}
void getroot(int u,int fa){
siz[u]=1;
int s=0;
for(int i=h[u];i;i=e[i].ne){
int v=e[i].v;
if(v==fa||del[v])continue;
getroot(v,u);
siz[u]+=siz[v];
s=max(s,siz[v]);
}
s=max(s,sum-siz[u]);
if(s<mxs) mxs=s, root=u;
}
void getdis(int u,int fa){
dis[++cnt]=d[u];
for(int i=h[u];i;i=e[i].ne){
int v=e[i].v;
if(v==fa||del[v])continue;
d[v]=d[u]+e[i].w;
getdis(v,u);
}
}
void calc(int u){
judge[0]=1;
int p=0;
// 计算经过根u的路径
for(int i=h[u];i;i=e[i].ne){
int v=e[i].v;
if(del[v])continue;
// 求出子树v的各点到u的距离
cnt=0;
d[v]=e[i].w;
getdis(v,u);
// 枚举距离和询问,判定答案
for(int j=1;j<=cnt;++j)
for(int k=1;k<=m;++k)
if(ask[k]>=dis[j])
ans[k]|=judge[ask[k]-dis[j]];
// 记录合法距离
for(int j=1;j<=cnt;++j)
if(dis[j]<INF)
q[++p]=dis[j], judge[q[p]]=1;
}
// 清空距离数组
for(int i=1;i<=p;++i) judge[q[i]]=0;
}
void divide(int u){
// 计算经过根u的路径
calc(u);
// 对u的子树进行分治
del[u]=1;
for(int i=h[u];i;i=e[i].ne){
int v=e[i].v;
if(del[v])continue;
mxs=sum=siz[v];
getroot(v,0); //求根
divide(root); //分治
}
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<n;++i){
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);add(v,u,w);
}
for(int i=1;i<=m;++i)
scanf("%d",&ask[i]);
mxs=sum=n;
getroot(1,0);
getroot(root,0); //重构siz[]
divide(root);
for(int i=1;i<=m;++i)
ans[i]?puts("AYE"):puts("NAY");
return 0;
}
- 本题还可以利用容斥原理统计答案
#include<iostream>
#include<algorithm>
using namespace std;
const int N=10005;
const int INF=10000005;
struct node{int v,w,ne;}e[N<<1];
int h[N],idx; //加边
int del[N],siz[N],mxs,sum,root;//求根
int dis[N],d[N],cnt; //求距离
int ans[N];//求路径
int n,m,ask[N];
void add(int u,int v,int w){
e[++idx].v=v; e[idx].w=w;
e[idx].ne=h[u]; h[u]=idx;
}
void getroot(int u,int fa){
siz[u]=1;
int s=0;
for(int i=h[u];i;i=e[i].ne){
int v=e[i].v;
if(v==fa||del[v])continue;
getroot(v,u);
siz[u]+=siz[v];
s=max(s,siz[v]);
}
s=max(s,sum-siz[u]);
if(s<mxs) mxs=s, root=u;
}
void getdis(int u,int fa){
dis[++cnt]=d[u];
for(int i=h[u];i;i=e[i].ne){
int v=e[i].v;
if(v==fa||del[v])continue;
d[v]=d[u]+e[i].w;
getdis(v,u);
}
}
void calc(int u, int w, int sign){
cnt = 0,d[u] = w;
getdis(u, 0);//求距离
sort(dis+1, dis+cnt+1);
for(int i=1;i<=m;i++){
int l=1, r=cnt;
while(l < r){
if(dis[l]+dis[r]<=ask[i]){
if(dis[l]+dis[r]==ask[i])ans[i]+=sign;
++l;
}
else --r;
}
}
}
void divide(int u){
calc(u, 0, 1); //求答案
del[u] = 1;
for(int i=h[u];i;i=e[i].ne){
int v = e[i].v;
if(del[v]) continue;
calc(v, e[i].w, -1); //容斥
mxs =sum = siz[v];
getroot(v, u); //求根
divide(root); //分治
}
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<n;++i){
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);add(v,u,w);
}
for(int i=1;i<=m;++i)
scanf("%d",&ask[i]);
mxs=sum=n;
getroot(1,0);
getroot(root,0); //重构siz[]
divide(root);
for(int i=1;i<=m;++i)
ans[i]?puts("AYE"):puts("NAY");
return 0;
}
例题:树上括号序列
题目描述
给定一棵树,树上每个节点的字符可以是 ( 或者 )。
求有序点对 的数量,使得 到 的路径上的点的字符(包括 和 )所构成的括号序列合法。
输入格式
第一行一个整数 ,表示节点数量。
第二行一个字符串,由 ( 和 ) 组成,表示每个节点的字符。
接下来 行,每行 个整数 和 ,表示树上的一条边。
输出格式
一行一个整数,表示点对的数量。
样例
输入1
4
(())
1 2
2 3
3 4
输出1
2
样例解释
合法的点对有 \((1,4)\) 和 \((2,3)\)。
输入2
5
())((
1 2
2 3
2 4
3 5
输出2
3
样例解释
合法的点对有 \((1,2)\),\((4,2)\) 和 \((5,3)\)。
输入3
7
)()()((
1 2
1 3
1 6
2 4
4 5
5 7
输出3
6
数据范围与提示
| 子任务编号 | \(n\le\) | 特殊性质 | 分值 |
|---|---|---|---|
| \(1\) | \(10^3\) | \(10\) | |
| \(2\) | \(3\times 10^5\) | 树是一条链 | \(30\) |
| \(3\) | \(3\times 10^5\) | \(60\) |
#include <iostream>
using namespace std;
using LL = long long;
const int N = 3e5 + 7, INF = 0x3F3F3F3F;
int n, cnt, hd[N], rt, tot, sz[N], mx;
char s[N];
struct Edge {
int v, nx;
} eg[N << 1];
void addE(int u, int v, int c) { eg[c] = { v, hd[u] }, hd[u] = c; }
inline void getmx(int &x, int y) {
if (x < y)
x = y;
}
inline void getmn(int &x, int y) {
if (x > y)
x = y;
}
bool vis[N];
void getrt(int u, int fa) {
sz[u] = 1;
int cur = 0;
for (int i = hd[u]; i; i = eg[i].nx) {
int v = eg[i].v;
if (vis[v] || v == fa)
continue;
getrt(v, u);
getmx(cur, sz[v]);
sz[u] += sz[v];
}
getmx(cur, tot - sz[u]);
if (cur < mx)
mx = cur, rt = u;
}
// tot1[i] 左括号比右括号多i个的前缀数量, tot2[i] 右括号比左括号多i的后缀数量
int up, cnt1[N], cnt2[N], tot1[N], tot2[N];
void dfs(int u, int fa, int s1, int s2, int mn1, int mn2) {
sz[u] = 1;
int t = (s[u] == '(' ? 1 : -1);
s1 += t, s2 -= t, mn1 = min(mn1, 0) + t, mn2 = min(mn2, 0) - t;
if (mn1 >= 0)
++cnt1[s1], getmx(up, s1);
if (mn2 >= 0)
++cnt2[s2], getmx(up, s2);
for (int i = hd[u]; i; i = eg[i].nx) {
int v = eg[i].v;
if (vis[v] || v == fa)
continue;
dfs(v, u, s1, s2, mn1, mn2);
sz[u] += sz[v];
}
}
LL ans;
void calc() {
for (int i = 0; i <= up; ++i) {
ans -= (LL)cnt1[i] * cnt2[i];
tot1[i] += cnt1[i], tot2[i] += cnt2[i];
cnt1[i] = cnt2[i] = 0;
}
}
void solve(int u) {
vis[u] = 1;
int hi = 0, t = (s[u] == '(' ? 1 : -1);
if (t == 1)
tot1[1] = hi = 1;
tot2[0] = 1;
for (int i = hd[u]; i; i = eg[i].nx) {
int v = eg[i].v;
if (vis[v])
continue;
up = 0;
dfs(v, 0, t, 0, t, 0);
getmx(hi, up);
calc();
}
for (int i = 0; i <= hi; ++i) {
ans += (LL)tot1[i] * tot2[i];
tot1[i] = tot2[i] = 0;
}
for (int i = hd[u]; i; i = eg[i].nx) {
int v = eg[i].v;
if (vis[v])
continue;
mx = INF, tot = sz[v];
getrt(v, 0);
solve(rt);
}
}
int main() {
scanf("%d%s", &n, s + 1);
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
addE(u, v, i << 1), addE(v, u, i << 1 | 1);
}
mx = INF, tot = n;
getrt(1, 0);
solve(rt);
printf("%lld\n", ans);
}

浙公网安备 33010602011771号