树上差分
树上差分与线性差分差不多,只不过是在树上进行差分,每次将两个点x和y的标志加1,将lca(x,y)和fa(lca(x,y))的标志减1,最后来一次深搜求和,就可以得到值了
下面给出几道例题
1.P3128 [USACO15DEC] Max Flow P
解析:
树上差分板子题,直接套班子,求完值后,求最大值即可
代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 2e6+39+7;
int depth[N],f[N][21],n,head[N],tot,k,cnt[N],ans;
struct node{
int u,v;
}edge[N<<1];
void add(int x,int y){
edge[++tot].u=head[x];
edge[tot].v=y;
head[x]=tot;
}
void dfs(int u,int fa){
depth[u]=depth[fa]+1;
f[u][0]=fa;
for(int i=1;(1<<i)<=depth[u];i++)f[u][i]=f[f[u][i-1]][i-1];
for(int i=head[u];i;i=edge[i].u){
if(edge[i].v==fa)continue;
dfs(edge[i].v,u);
}
}
int lca(int x,int y){
if(depth[x]>depth[y])swap(x,y);
for(int i=20;i>=0;i--)if(depth[y]-(1<<i)>=depth[x])y=f[y][i];
if(x==y)return x;
for(int i=20;i>=0;i--){
if(f[x][i]==f[y][i])continue;
x=f[x][i],y=f[y][i];
}
return f[x][0];
}
void dfss(int u,int fa){
for(int i=head[u];i;i=edge[i].u){
if(edge[i].v==fa)continue;
dfss(edge[i].v,u);
cnt[u]+=cnt[edge[i].v];
}
}
int main(){
cin>>n>>k;
for(int i=1,a,b;i<n;i++)cin>>a>>b,add(a,b),add(b,a);
dfs(1,0);
for(int i=1,x,y,la;i<=k;i++){
cin>>x>>y;
la=lca(x,y);
cnt[x]++;cnt[y]++;
cnt[la]--;cnt[f[la][0]]--;
}
dfss(1,0);
for(int i=1;i<=n;i++)ans=max(ans,cnt[i]);
cout<<ans;
return 0;
}
解析:
树上差分板子题,每个点都会多算1次,所以,在深搜求完值之后,需要把每个数减1,依次输出即可
代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 2e6+39+7;
int a[N],depth[N],f[N][21],n,head[N],tot,k,cnt[N],ans;
struct node{
int u,v;
}edge[N<<1];
void add(int x,int y){
edge[++tot].u=head[x];
edge[tot].v=y;
head[x]=tot;
}
void dfs(int u,int fa){
depth[u]=depth[fa]+1;
f[u][0]=fa;
for(int i=1;(1<<i)<=depth[u];i++)f[u][i]=f[f[u][i-1]][i-1];
for(int i=head[u];i;i=edge[i].u){
if(edge[i].v==fa)continue;
dfs(edge[i].v,u);
}
}
int lca(int x,int y){
if(depth[x]>depth[y])swap(x,y);
for(int i=20;i>=0;i--)if(depth[y]-(1<<i)>=depth[x])y=f[y][i];
if(x==y)return x;
for(int i=20;i>=0;i--){
if(f[x][i]==f[y][i])continue;
x=f[x][i],y=f[y][i];
}
return f[x][0];
}
void dfss(int u,int fa){
for(int i=head[u];i;i=edge[i].u){
if(edge[i].v==fa)continue;
dfss(edge[i].v,u);
cnt[u]+=cnt[edge[i].v];
}
}
int main(){
cin>>n;
for(int i=1;i<=n;i++)cin>>a[i];
for(int i=1,a,b;i<n;i++){
cin>>a>>b;
add(a,b);
add(b,a);
}
dfs(1,0);
for(int i=1,LCA;i<n;i++){
LCA=lca(a[i],a[i+1]);
cnt[a[i]]++;cnt[a[i+1]]++;
cnt[LCA]--;cnt[f[LCA][0]]--;
}
dfss(1,0);
for(int i=2;i<=n;i++)cnt[a[i]]--;
for(int i=1;i<=n;i++)cout<<cnt[i]<<'\n';
return 0;
}
解析:
这道题使用了树上差分和记录路径的方法,预处理init数组,fa数组,dep数组等,进行求解,使用静态算法,存储每一次的问题,和两点之间的距离和lca,使用二分枚举时间,即可得到答案
代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1e6+39+7;
struct node{
int x,y,lca,dis;
bool operator <(const node &a)const{
return dis<a.dis;
}
}query[N];
struct edg{
int to,next,w;
}e[N<<1];
int l,r,m,dep[N],fa[N][21],d[N],n,head[N],tot=-1,k,ans,init[N],cnt[N];
void add(int x,int y,int z){
e[++tot]=(edg){y,head[x],z};
head[x]=tot;
}
void dfs(int x,int father,int dis){
dep[x]=dep[father]+1;
fa[x][0]=father;init[x]=dis;
for(int i=1;(1<<i)<=dep[x];i++)fa[x][i]=fa[fa[x][i-1]][i-1];
for(int i=head[x];~i;i=e[i].next){
int y=e[i].to;
if(y==father)continue;
d[y]=d[x]+e[i].w;
dfs(y,x,e[i].w);
}
}
int lca(int x,int y){
if(dep[x]>dep[y])swap(x,y);
for(int i=20;i>=0;i--)if(dep[y]-(1<<i)>=dep[x])y=fa[y][i];
if(x==y)return x;
for(int i=20;i>=0;i--){
if(fa[x][i]==fa[y][i])continue;
x=fa[x][i];y=fa[y][i];
}
return fa[x][0];
}
void dfss(int u,int father){
for(int i=head[u];~i;i=e[i].next){
int y=e[i].to;
if(y==father)continue;
dfss(y,u);
cnt[u]+=cnt[y];
}
}
bool ok(int x){
int num=0,now=0;
for(int i=1;i<=n;i++)cnt[i]=0;
for(int i=1;i<=m;i++){
if(query[i].dis<=x)continue;
cnt[query[i].x]++;cnt[query[i].y]++;
cnt[query[i].lca]-=2;
num++;
}
dfss(1,0);
for(int i=1;i<=n;i++)if(cnt[i]==num)now=max(now,init[i]);
return query[m].dis-now<=x;
}
int main(){
memset(head,-1,sizeof(head));
cin>>n>>m;
for(int i=1,x,y,z;i<n;i++){
cin>>x>>y>>z;
add(x,y,z);add(y,x,z);
}
dfs(1,0,0);
for(int i=1,x,y;i<=m;i++){
cin>>x>>y;
query[i].lca=lca(x,y);
query[i].dis=d[x]+d[y]-2*d[query[i].lca];
r=max(r,query[i].dis);
query[i].x=x;query[i].y=y;
}
sort(query+1,query+m+1);
while(l<=r){
int mid=(l+r)/2;
if(ok(mid))r=mid-1;
else l=mid+1;
}
cout<<l;
return 0;
}

浙公网安备 33010602011771号