关于点分治
点分治是一个常见的维护树上点对的数据结构,有静态点分治与动态点分治(点分树),本篇只介绍静态点分治。
例题:点分治1
题意:给定一棵有 n 个点的树,m次询问,询问树上距离为 k 的点对是否存在。
n <= 10^4,m<=100,k <= 10^7
思路:看到该题可以想到一个暴力思路,对于每个点进行一次dfs即可。这是直接枚举点再枚举路径的想法,但这样我们会重复枚举到许多无效的点与路径,考虑减少这种枚举。我们可以尝试从路径入手,对于一棵有根树,树上的路径可分为两类,一类经过根节点,一类不经过根节点。
对于经过根节点的路径,其由两个在根节点的不同子节点的子树内的节点连接而成,可将其分解为两条路径,分别都以根节点作为一个端点,我们可以在线性时间内处理出这种被分解的路径的长度并存起来,然后找是否有两条加起来长度为k的路径即可。
对于不经过根节点的路径,其由两个在根节点的相同子节点的子树内的节点连接而成,我们可以考虑进行递归,进行所谓的“换根”操作,使得这种路径变为第一种路径。
这样我们每一次都把根节点的每个子树进行递归分治,若最终递归h层,则时间复杂度为O(hn),考虑优化这个层数,根据树的重心的性质可得,每一次以子树的重心作为根节点进行传递,那么每一次该节点的最大子树的大小不会超过总节点的一半,因此最多log(n)层即可完成递归。
两种写法:
桶统计:
#include<bits/stdc++.h>
using namespace std;
#define int long long
int const maxn = 6e4 + 10;
int const M = 1e7 + 10;
struct node{
int v,w,nxt;
}e[maxn];
int n,m,q[maxn],cnt,head[maxn],rt,siz[maxn],mxp[maxn],d[maxn],mine[M],tot,a[maxn];
bool vis[maxn],ok[maxn];
void add(int u,int v,int w){
e[++cnt].v = v;
e[cnt].w = w;
e[cnt].nxt = head[u];
head[u] = cnt;
}
void getrt(int u,int fa,int sum){
siz[u] = 1;
mxp[u] = 0;
for(int i = head[u];i;i = e[i].nxt){
int v = e[i].v;
if(v==fa||vis[v])continue;
getrt(v,u,sum);
siz[u]+=siz[v];
mxp[u] = max(mxp[u],siz[v]);
}
mxp[u] = max(mxp[u],sum - siz[u]);
if(!rt||mxp[u] < mxp[rt])rt = u;
}
void getdis(int u,int fa,int dis){
if(dis>M)return ;
a[++tot] = u;
d[u] = dis;
for(int i = head[u];i;i = e[i].nxt){
int v = e[i].v;
if(v==fa||vis[v])continue;
getdis(v,u,dis + e[i].w);
}
}
void calc(int u){
tot = 0;
d[u] = 0;
mine[0] = u;
for(int i = head[u];i;i = e[i].nxt){
int v = e[i].v;
if(vis[v])continue;
int ptot = tot;
getdis(v,u,e[i].w);
for(int j = ptot;j <= tot;j ++){
for(int p = 1;p <= m;p ++){
if(d[a[j]] > q[p]||ok[p])continue;
if(d[mine[q[p] - d[a[j]]]] + d[a[j]] == q[p]){
ok[p] = 1;
}
}
}
for(int j = ptot;j <= tot;j ++){
mine[d[a[j]]] = a[j];
}
}
for(int i = 0;i <= tot;i ++){
mine[d[a[i]]] = 0;
}
}
void solve(int u){
vis[u] = 1;
calc(u);
for(int i = head[u];i;i = e[i].nxt){
int v = e[i].v;
if(vis[v])continue;
rt = 0;
getrt(v,u,siz[v]);
solve(rt);
}
}
signed main(){
cin >> n >> m;
for(int i = 1;i < n;i ++){
int u,v,w;
cin >> u >> v>> w;
add(u,v,w);
add(v,u,w);
}
for(int i = 1;i <= m;i ++){
cin >> q[i];
}
mxp[0] = n;
getrt(1,0,n);
solve(rt);
for(int i = 1;i <= m;i ++){
if(ok[i]){
cout <<"AYE"<<'\n';
}else {
cout <<"NAY"<<'\n';
}
}
}
双指针:
#include<bits/stdc++.h>
using namespace std;
#define int long long
int const maxn = 2e4+10;
struct node{
int v,w,nxt;
}p[maxn];
int head[maxn];
int cnt,tot;
int n,m;
int siz[maxn],mxp[maxn],d[maxn],b[maxn],a[maxn],q[maxn],res[maxn];
bool vis[maxn];
void add(int u,int v,int w){
p[++cnt].v = v;
p[cnt].w = w;
p[cnt].nxt = head[u];
head[u] = cnt;
}
int rt;
void getroot(int u,int fa,int sum){
siz[u] = 1;mxp[u] = 0;
for(int i = head[u];i; i = p[i].nxt){
int v = p[i].v;
if(v==fa||vis[v])continue;
getroot(v,u,sum);
siz[u] += siz[v];
mxp[u] = max(mxp[u],siz[v]);
}
mxp[u] = max(mxp[u],sum - siz[u]);
if(!rt || mxp[u] < mxp[rt]){
rt = u;
}
}
bool cmp(int x,int y){
return d[x] < d[y];
}
void getdis(int u,int fa,int dis,int bel){
a[++tot] = u;
d[u] = dis;
b[u] = bel;
for(int i = head[u];i; i = p[i].nxt){
int v = p[i].v;
if(v==fa||vis[v])continue;
getdis(v,u,dis + p[i].w,bel);
}
}
void calc(int u){
tot = 0;
a[++tot] = u;
d[u] = 0;
b[u] = u;
for(int i = head[u];i; i = p[i].nxt){
int v = p[i].v;
if(vis[v])continue;
getdis(v,u,p[i].w,v);
}
sort(a+1,a+tot+1,cmp);
for(int i = 1;i <= m;i ++){
int l = 1,r = tot;
if(res[i])continue;
while(l < r){
if(d[a[l]] + d[a[r]] > q[i])r--;
else if(d[a[l]] + d[a[r]] < q[i])l++;
else if(b[a[l]]==b[a[r]]){
if(d[a[r]]==d[a[r-1]])r--;
else l++;
}else {
res[i]++;
break;
}
}
}
}
void solve(int u){
vis[u] = 1;
calc(u);
for(int i = head[u];i;i = p[i].nxt){
int v = p[i].v;
if(vis[v])continue;
rt = 0;
getroot(v,0,siz[v]);
solve(rt);
}
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin >> n >> m ;
for(int i = 1;i <= n-1 ;i ++){
int u,v,w;
cin >> u >> v >> w;
add(u,v,w);
add(v,u,w);
}
for(int i = 1;i <= m;i ++){
cin >> q[i] ;
}
mxp[0] = n;
getroot(1,0,n);
solve(rt);
for(int i = 1;i <= m;i ++){
if(res[i]){
cout << "AYE"<<'\n';
}else {
cout << "NAY"<<'\n';
}
}
return 0;
}
综合两种写法可以发现,第一种桶统计的写法可维护信息广,但易卡常,而且内存要求高,第二种双指针写法时间常数小,简洁易懂。
例题:Tree
题意:给定一棵 n 个节点的树,每条边有边权,求出树上两点距离小于等于 k 的点对数量。
n <= 4 * 10^4
k <= 2 * 10^4
如上的思路,可用容斥处理处在同一子树内的不合法路径,算出根节点的路径数量之和再减去位于同一子树内的路径数量之和即可。
双指针代码如下:
#include<bits/stdc++.h>
using namespace std;
#define int long long
int const maxn = 8e4+10;
struct node{
int v,w,nxt;
}e[maxn];
int n,k;
int head[maxn],cnt;
void add(int u,int v,int w){
e[++cnt].w = w;
e[cnt].v = v;
e[cnt].nxt = head[u];
head[u] = cnt;
}
bool vis[maxn];
int siz[maxn],mxp[maxn],rt;
void getrt(int u,int fa,int sum){
siz[u] = 1;
mxp[u] = 0;
for(int i = head[u];i;i = e[i].nxt){
int v = e[i].v;
if(v==fa||vis[v])continue;
getrt(v,u,sum);
siz[u]+=siz[v];
mxp[u] = max(mxp[u],siz[v]);
}
mxp[u] = max(mxp[u],sum - siz[u]);
if(!rt||mxp[u] < mxp[rt])rt = u;
}
int tot;
int a[maxn],b[maxn],d[maxn];
bool cmp(int x,int y){
return d[x] < d[y];
}
void getdis(int u,int fa,int dis){
a[++tot] = dis;
for(int i = head[u];i; i =e[i].nxt){
int v = e[i].v;
if(v==fa||vis[v])continue;
getdis(v,u,dis + e[i].w);
}
}
int calc(int u,int w){
tot = 0;
getdis(u,0,w);
sort(a+1,a+tot+1);
int l = 1,r = tot,res = 0;
while(l <= r){
if(a[l] + a[r] <= k){
res += (r-l);
l++;
}else {
r--;
}
}
return res;
}
int ans;
void solve(int u){
vis[u] = 1;
ans += calc(u,0);
for(int i = head[u];i;i = e[i].nxt){
int v = e[i].v;
if(vis[v])continue;
ans -= calc(v,e[i].w);
rt = 0;
getrt(v,0,siz[v]);
solve(rt);
}
}
signed main(){
cin >> n;
for(int i = 1;i < n;i ++){
int u,v,w;
cin >> u >> v >> w;
add(u,v,w);
add(v,u,w);
}
cin >> k ;
getrt(1,0,n);
solve(rt);
cout << ans;
}
例题:Race
题意:给一棵树,每条边有权。求一条简单路径,权值和等于 k,且边的数量最小。
n<=2e5
k<=1e6
思路:该题用双指针写的话会非常难处理,可以想到一种情况,对于d[a[l]]+d[a[r]]是合法的的情况,可能a[r]后的一段值都和a[l]相等,a[r]前的一段都和a[r]相等,我们需要让其两两分别配对并对边数取min。而用桶统计写法就会比较清晰。
#include<bits/stdc++.h>
using namespace std;
#define int long long
int const maxn = 6e5 + 10;
int const M = 1e6+1000;
int rt;
struct node{
int v,w,nxt;
}e[maxn];
int head[maxn],cnt;
void add(int u,int v,int w){
e[++cnt].v = v;
e[cnt].w = w;
e[cnt].nxt = head[u];
head[u] = cnt;
}
bool vis[maxn];
int n,k,ans;
int dep[maxn],siz[maxn],mxp[maxn],son[maxn],top[maxn],si[maxn],f1[maxn];
void getrt(int u,int fa,int sum){
siz[u] = 1;
mxp[u] = 0;
for(int i = head[u];i;i = e[i].nxt){
int v = e[i].v;
if(v==fa||vis[v])continue;
getrt(v,u,sum);
siz[u] += siz[v];
mxp[u] = max(mxp[u],siz[v]);
}
mxp[u] = max(mxp[u],sum - siz[u]);
if(mxp[u] < mxp[rt]){
rt = u ;
}
}
int d2[maxn],d[maxn],a[maxn],b[maxn],mine[M];
int tot;
void getdis(int u,int fa,int dis,int dis2,int bel){
if(dis > k)return ;
d[++tot] = dis;
d2[tot] = dis2;
for(int i = head[u];i;i = e[i].nxt){
int v = e[i].v;
if(v==fa||vis[v])continue;
getdis(v,u,dis+e[i].w,dis2+1,bel);
}
}
void calc(int u){
tot = 0;
mine[0] = 0;
for(int i = head[u];i;i = e[i].nxt){
int v = e[i].v;
int ptot = tot;
if(vis[v])continue;
getdis(v,u,e[i].w,1,v);
for(int j = ptot+1;j <= tot;j ++){
ans = min(ans,mine[k - d[j]] + d2[j]);
}
for(int j = ptot+1;j <= tot;j ++){
mine[d[j]] = min(mine[d[j]],d2[j]);
}
}
for(int i = 1;i <= tot;i++){
mine[d[i]] = 1e9+10;
}
}
void solve(int u){
vis[u] = 1;
calc(u);
for(int i = head[u];i;i = e[i].nxt){
int v = e[i].v;
if(vis[v])continue;
rt = 0;
getrt(v,0,siz[v]);
solve(rt);
}
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
ans = 1e9+10;
cin >> n >> k ;
for(int i = 1;i <= n - 1;i ++){
int u,v,w;
cin >> u >> v >> w;
add(u+1,v+1,w);
add(v+1,u+1,w);
}
mxp[0] = n+1;
memset(mine,0x3f,sizeof mine);
rt = 0;
getrt(1,0,n);
solve(rt);
if(ans<n){
cout << ans;
}else {
cout << -1;
}
return 0;
}

浙公网安备 33010602011771号