可持久化线段树(可持久化树组,主席树)

可持久化线段树(可持久化树组,主席树)

概念

可持久化线段树线段树的扩展,它的每一个操作都是建立在线段树的某一个历史版本上(相当于每一步操作都会产生一个新线段树,我们在指定线段树上进行操作

基本思想

既然是在某个历史版本上操作,那么我们的每一步操作都建立一个全新的线段树。但很容易发现,这样做会浪费大量空间与时间

  • 对于每一次建树,时间复杂度至少为 \(O(n)\) (节点数大约为 \(4n\))显然频繁建树会导致时间复杂度偏高
  • 对于每一次建树,空间复杂度大致为 \(O(n)\) (节点数大约为 \(4n\))显然频繁建树会导致空间复杂度偏高

我们观察每一次修改操作:每一次修改最多更新 \(log_2n\) 个节点的值,也就是说只会更新从根到叶子节点的一条链!如下图(修改 \(2\)):

graph TB 1(1,4) 2(1,2) 3(3,4) 4(1,1) 5(2,2) 6(3,3) 7(4,4) 1---2 1---3 2---4 2---5 3---6 3---7

我们只会修改 \([1,4]\)\([1,2]\)\([2,2]\) 三个节点的值:

graph TB 1(1,4) 2(1,2) 3( ) 4( ) 5(2,2) 6( ) 7( ) 1---2 1---3 2---4 2---5 3---6 3---7

所以,我们只需要再线段树上新增一条链即可:

graph TB n1[1,4] n2[1,2] n5[2,2] n1---n2 n2---n5 n2---z5 z1(1,4) z2(1,2) z3(3,4) z4(1,1) z5(2,2) z6(3,3) z7(4,4) n1[1,4] n2[1,2] n5[2,2] z1---z2 z1---z3 z2---z4 z2---z5 z3---z6 z3---z7

这样修改,时间复杂度为 \(O(log_2n)\) ,空间复杂度为 \(O(nlog_2n)\)

算法实现

对于一颗可持久化线段树,我们让根节点作为访问某个历史版本的媒介(每一次操作都会更新根节点

建树

可持久化线段树的建树与普通线段树相同:

struct node {
  int lson,rson; // 左右儿子
  ll val; // 值
  node(int lson = -1,int rson = -1,int val = 0) lson(lson),rson(rson),val(val) {}; // -1代表没有左右儿子
};
node tree[maxn * 25 + 5]; // nlog_2n 各节点,每次增加 log_2n 个
int root[maxm + 5]; // 不同版本的根节点
int nodecnt = 0;
int build(int l,int r) {
  int u = ++ nodecnt; // 新建节点
  if (l == r) { // 叶子节点
  	tree[u].val = a[l]; // a[i]为初始值
    return u; // 返回节点
  }
  int mid = (l + r) >> 1;
  tree[u].lson = build(l,mid),tree[u].rson = build(mid + 1,r); // 建造左右子树
  return u;
}

主程序

root[0] = build(1,n); // 0号版本

更新(建链)

可持久化线段树的建链与单点修改类似:

  1. 从根节点开始操作,访问每一个包含操作位置的节点
  2. 对于每一个访问到的节点 \(u\),先将 \(u\) 赋值为操作版本的 \(u\) 点(原来版本的左右子节点与值),再进行修改
int rebuild(int pre,int l,int r,int pos,int val) { // pre:操作的版本的父节点 pos:操作位置 val:值
  int u = ++ nodecnt; // 新建节点
  tree[u] = tree[pre]; // 复制原来版本的节点
  if (l == r) {
    tree[u].val = val; // 赋值(本文为基础可持久化数组的写法)
   	return u;
  }
  int mid = (l + r) >> 1;
  if (pos <= mid) tree[u].lson = rebuild(tree[pre].lson,l,mid,pos,val); // 位置在左子树
  else tree[u].rson = rebuild(tree[pre].lson,mid + 1,r,pos,val); // 位置在右子树
  return u;
}

主程序

root[i] = rebuild(root[pre],1,n,pos,val); // 0号版本

查询(单点)

可持久化线段树的单点查询与普通线段树的单点查询类似:

int query(int pre,int l,int r,int pos) {
  if (l == r) return tree[pre].val; // 叶子节点返回
  int mid = (l + r) >> 1;
  if (pos <= mid) return query(tree[pre].lson,l,mid,pos); // 左子节点
  else return query(tree[pre].rson,mid + 1,r,pos); // 右子节点
}

例题

  1. 洛谷P3919 【模板】可持久化线段树 1(可持久化数组)

    #include <bits/stdc++.h>
    #define ll long long
    using namespace std;
    const int maxn = 1e6;
    struct node { 
    	int lson,rson,val; 
    	node(int lson = -1,int rson = -1,int val = 0) : lson(lson),rson(rson),val(val) {};
    };
    int n,m,nodecnt = 0;
    node tree[maxn * 25 + 5];
    int root[maxn];
    int a[maxn + 5];
    int build(int l,int r) {
    	int u = ++ nodecnt;
    	if (l == r) {
    		tree[u].val = a[l];
    		return u;
    	}
    	int mid = (l + r) >> 1;
    	tree[u].lson = build(l,mid);
    	tree[u].rson = build(mid + 1,r);
    	return u;
    }
    int rebuild(int pre,int l,int r,int pos,int val) {
    	int u = ++ nodecnt;
    	tree[u] = tree[pre];
    	if (l == r) {
    		tree[u].val = val;
    		return u;
    	}
    	int mid = (l + r) >> 1;
    	if (pos <= mid) tree[u].lson = rebuild(tree[u].lson,l,mid,pos,val);
    	else tree[u].rson = rebuild(tree[u].rson,mid + 1,r,pos,val);
    	return u;
    }
    int query(int pre,int l,int r,int pos) {
    	if (l == r) return tree[pre].val;
    	int mid = (l + r) >> 1;
    	if (pos <= mid) return query(tree[pre].lson,l,mid,pos);
    	else return query(tree[pre].rson,mid + 1,r,pos);
    }
    int main() {
    	scanf("%d %d",&n,&m);
    	for (int i = 1;i <= n;i ++) scanf("%d",&a[i]);
    	root[0] = build(1,n);
    	for (int i = 1;i <= m;i ++) {
    		int version,op; scanf("%d %d",&version,&op);
    		if (op == 1) {
    			int pos,val; scanf("%d %d",&pos,&val);
    			root[i] = rebuild(root[version],1,n,pos,val);
    		} else {
    			int pos; scanf("%d",&pos);
    			printf("%d\n",query(root[version],1,n,pos));
    			root[i] = root[version];
    		}
    	}
    	return 0;
    }
    
  2. 洛谷P3834 【模板】可持久化线段树 2(区间 \(k\) 小值)

    #include <bits/stdc++.h>
    #define ll long long
    #define mid ((l + r) >> 1)
    using namespace std;
    const int maxn = 2e5;
    struct node { 
    	int lson,rson,sum; 
    	node(int lson = -1,int rson = -1,int sum = 0) : lson(lson),rson(rson),sum(sum) {};
    };
    int n,m,nodecnt = 0;
    node tree[maxn * 25 + 5];
    int root[maxn];
    int t[maxn + 5],a[maxn + 5];
    void init() {
    	sort(t + 1,t + n + 1);
    	for (int i = 1;i <= n;i ++) a[i] = lower_bound(t + 1,t + n + 1,a[i]) - t;
    }
    int build(int l,int r) {
    	int u = ++ nodecnt;
    	if (l == r) return u;
    	tree[u].lson = build(l,mid);
    	tree[u].rson = build(mid + 1,r);
    	return u;
    }
    int rebuild(int pre,int l,int r,int pos) {
    	int u = ++ nodecnt;
    	tree[u] = tree[pre];
    	tree[u].sum ++;
    	if (l == r) return u;
    	if (pos <= mid) tree[u].lson = rebuild(tree[pre].lson,l,mid,pos);
    	else tree[u].rson = rebuild(tree[pre].rson,mid + 1,r,pos);
    	return u;
    }
    int query(int pre,int last,int l,int r,int k) {
    	if (l == r) return l;
    	int x = tree[tree[last].lson].sum - tree[tree[pre].lson].sum;
    	if (k <= x) return query(tree[pre].lson,tree[last].lson,l,mid,k);
    	else return query(tree[pre].rson,tree[last].rson,mid + 1,r,k - x);
    }
    int main() {
      scanf("%d %d",&n,&m);
      for (int i = 1;i <= n;i ++) {
      	scanf("%d",&t[i]);
      	a[i] = t[i];
    	}
      init();
      root[0] = build(1,n);
      for (int i = 1;i <= n;i ++) root[i] = rebuild(root[i - 1],1,n,a[i]);
      for (int i = 1;i <= m;i ++) {
      	int x,y,k; scanf("%d %d %d",&x,&y,&k);
      	printf("%d\n",t[query(root[x - 1],root[y],1,n,k)]);
    	}
    	return 0;
    }
    

扩展:

可持久化并查集

利用可持久化数组编写可持久化并查集

  1. 合并:进行按秩合并(类似于把低的树接到高的树上,把少的树接到多的树上)

  2. 查询:我们不进行路径压缩(时间复杂度增加),但这样可能导致栈溢出

以下代码通过比较集合大小进行按秩合并

#include <bits/stdc++.h>
using namespace std;
const int maxn = 4e5;
struct TREE {
	struct node { int lson,rson,val; };
	int n,m,nodecnt = 0;
	node tree[maxn * 25 + 5];
	int root[maxn + 5];
	int st[maxn + 5];
	int build(int l,int r) {
		int u = ++ nodecnt;
		if (l == r) {
			tree[u].val = st[l];
			return u;
		}
		int mid = (l + r) >> 1;
		tree[u].lson = build(l,mid);
		tree[u].rson = build(mid + 1,r);
		return u;
	}
	int rebuild(int pre,int l,int r,int pos,int val) {
		int u = ++ nodecnt;
		tree[u] = tree[pre];
		if (l == r) {
			tree[u].val = val;
			return u;
		}
		int mid = (l + r) >> 1;
		if (pos <= mid) tree[u].lson = rebuild(tree[u].lson,l,mid,pos,val);
		else tree[u].rson = rebuild(tree[u].rson,mid + 1,r,pos,val);
		return u;
	}
	int query(int pre,int l,int r,int pos) {
		if (l == r) return tree[pre].val;
		int mid = (l + r) >> 1;
		if (pos <= mid) return query(tree[pre].lson,l,mid,pos);
		else return query(tree[pre].rson,mid + 1,r,pos);
	}
};
int n,m;
TREE fa,sz;
int now_version = 0;
int find(int version,int x) {
	while (true) {
		int fx = fa.query(fa.root[version],1,n,x);
		if (fx == x) return x;
		x = fx;
	}		
}
void merge(int pre_version,int x,int y) {
	int fx = find(pre_version,x),fy = find(pre_version,y);
	if (fx == fy) return ;
	int sx = sz.query(sz.root[pre_version],1,n,fx),sy = sz.query(sz.root[pre_version],1,n,fy);
	if (sx <= sy) {
		fa.root[pre_version] = fa.rebuild(fa.root[pre_version],1,n,fx,fy);
		sz.root[pre_version] = sz.rebuild(sz.root[pre_version],1,n,fy,sx + sy);
	} else {
		fa.root[pre_version] = fa.rebuild(fa.root[pre_version],1,n,fy,fx);
		sz.root[pre_version] = sz.rebuild(sz.root[pre_version],1,n,fx,sx + sy);
	}
}
int main() {
	scanf("%d %d",&n,&m);
	for (int i = 1;i <= n;i ++) fa.st[i] = i;
	fa.root[0] = fa.build(1,n);
	for (int i = 1;i <= n;i ++) sz.st[i] = 1;
	sz.root[0] = sz.build(1,n);
	for (int i = 1;i <= m;i ++) {
		int op; scanf("%d",&op);
		fa.root[i] = fa.root[i - 1];
		sz.root[i] = sz.root[i - 1];
		if (op == 1) {
			int x,y; scanf("%d %d",&x,&y);
			merge(i,x,y);
		} else if (op == 2) {
			int x; scanf("%d",&x);
			fa.root[i] = fa.root[x];
			sz.root[i] = sz.root[x];
		} else if (op == 3) {
			int x,y; scanf("%d %d",&x,&y);
			if (find(i,x) == find(i,y)) printf("1\n");
			else printf("0\n");
		}
	}
	return 0;
}
posted @ 2025-04-12 23:02  nightmare_lhh  阅读(12)  评论(0)    收藏  举报