主席树
haha这将是我最简洁的一篇博文
引入
对于线段树上的单点修改，容易看出一次修改只会影响从根到对应叶子这条路径上的 $O(\log N)$ 个节点。
因此，每次修改时只需新建路径上的这 $O(\log N)$ 个节点，其余节点直接与上一个版本共享，就能把修改后的状态作为一个新版本保存下来，这就是可持久化线段树。它写起来并不麻烦，关键在于应用时的一些小技巧。
其实，需要用到可持久化线段树的场合，无非是在进行修改的同时还要保留并查询多个历史版本的时候。下面贴上几道例题。
栗1
查询区间第 k 小，即区间内的数从小到大排序后的第 k 个(题目戳这里)
贴上CODE:
// Example 1: k-th smallest number in a[a..b] for a static array with many queries.
// Values are compressed to ranks 1..tot. Version i of a persistent segment tree
// over the rank domain counts how many of a[1..i] have each rank, so
// "version b minus version a-1" describes exactly the elements of a[a..b].
#include <algorithm>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdio>
#include <vector>
#include <cmath>
#include <queue>
#include <map>
using namespace std;

const int maxn = 400010;

// v: input value, pm: compressed rank (1-based), id: original position.
struct lsh {
    int v, pm, id;
} A[maxn];

// Sort comparators (renamed: names starting with '_' are reserved at global scope).
static bool by_value(const lsh &x, const lsh &y) { return x.v < y.v; }
static bool by_index(const lsh &x, const lsh &y) { return x.id < y.id; }

// Node pool: spac is the last allocated node. Node 0 is never allocated and
// stays all-zero, so it behaves as an empty subtree.
int spac = 0, Root[maxn];
int lson[maxn * 10], rson[maxn * 10], cnt[maxn * 10];

// Builds an all-zero tree over ranks [l, r]; returns its root.
int build(int l, int r) {
    int x = ++spac;
    if (l == r) return x;
    int mid = (l + r) >> 1;
    lson[x] = build(l, mid);
    rson[x] = build(mid + 1, r);
    return x;
}

// Allocates a new node that is a copy of node x.
int Copy(int x) {
    int y = ++spac;
    lson[y] = lson[x], rson[y] = rson[x], cnt[y] = cnt[x];
    return y;
}

// Returns the root of a new version equal to version x plus one occurrence of
// rank p. Only the nodes on the path to leaf p are copied; the rest are shared.
int Insert(int x, int l, int r, int p) {
    int t = Copy(x);
    cnt[t]++;
    if (l == r) return t;
    int mid = (l + r) >> 1;
    if (p <= mid) lson[t] = Insert(lson[x], l, mid, p);
    else rson[t] = Insert(rson[x], mid + 1, r, p);
    return t;
}

// Walks versions a and b in parallel and returns the rank of the k-th smallest
// element of the multiset "version b minus version a" within ranks [l, r].
int que(int a, int b, int l, int r, int k) {
    if (l == r) return l;
    int x = cnt[lson[b]] - cnt[lson[a]];  // elements whose rank is in the left half
    int mid = (l + r) >> 1;
    if (x >= k) return que(lson[a], lson[b], l, mid, k);
    return que(rson[a], rson[b], mid + 1, r, k - x);
}

int getnum[maxn];  // getnum[rank] = the original value that has this rank

int main() {
    int n, Q;
    // n < 1 would make build(1, 0) recurse forever.
    if (scanf("%d %d", &n, &Q) != 2 || n < 1) return 0;
    for (int i = 1; i <= n; i++) scanf("%d", &A[i].v), A[i].id = i;

    sort(A + 1, A + n + 1, by_value);
    int tot = 0;
    for (int i = 1; i <= n; i++) {
        // Open a new rank at the first element and whenever the value changes.
        // (The previous sentinel A[0].v = -10000 broke when an input equalled it.)
        if (i == 1 || A[i].v != A[i - 1].v) tot++;
        A[i].pm = tot;
        getnum[tot] = A[i].v;
    }
    sort(A + 1, A + n + 1, by_index);  // back to input order

    Root[0] = build(1, tot);
    for (int i = 1; i <= n; i++) Root[i] = Insert(Root[i - 1], 1, tot, A[i].pm);

    for (int i = 1; i <= Q; i++) {
        int a, b, k;
        if (scanf("%d %d %d", &a, &b, &k) != 3) break;
        printf("%d\n", getnum[que(Root[a - 1], Root[b], 1, tot, k)]);
    }
    return 0;
}
栗2
可持久化线段树的区间修改（使用标记永久化）(题目戳这里)
// Example 2: persistent range add / range sum over a[1..n] with time travel.
// Commands: C l r d  -> add d to a[l..r], producing the next version (time + 1)
//           Q l r    -> sum of a[l..r] in the current version
//           H l r t  -> sum of a[l..r] in version t
//           B t      -> make version t current again
// Tags are never pushed down: each node keeps its own add tag, and queries
// accumulate the tags of the ancestors along the path.
#include <algorithm>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdio>
#include <vector>
#include <cmath>
#include <queue>
#include <map>
using namespace std;

const int maxn = 100010;

int spac = 0, Root[maxn], a[maxn];
// laz[x]: add applied to every element of x's range (never pushed down).
// cnt[x]: sum of x's range including laz of x and its descendants, but not of
//         its ancestors.
long long laz[maxn * 25];
int lson[maxn * 25], rson[maxn * 25];
long long cnt[maxn * 25];

// Builds version 0 over a[l..r]; returns the root. Every field of each new node
// is written here, so the node pool does not need clearing between test cases.
int build(int l, int r) {
    int x = ++spac;
    laz[x] = 0;
    if (l == r) {
        lson[x] = rson[x] = 0;
        cnt[x] = a[l];
        return x;
    }
    int mid = (l + r) >> 1;
    lson[x] = build(l, mid);
    rson[x] = build(mid + 1, r);
    cnt[x] = cnt[lson[x]] + cnt[rson[x]];
    return x;
}

// Allocates a new node that is a copy of node x.
int Copy(int x) {
    int y = ++spac;
    lson[y] = lson[x], rson[y] = rson[x], cnt[y] = cnt[x], laz[y] = laz[x];
    return y;
}

// Returns the root of a new version: version x with p added to every element of
// [ll, rr]. Subtrees disjoint from [ll, rr] are shared with version x.
int update(int x, int l, int r, int ll, int rr, long long p) {
    if (l > rr || r < ll) return x;
    int t = Copy(x);
    if (l >= ll && r <= rr) {
        // Fully covered: leave the tag on this node instead of pushing it down.
        laz[t] += p;
        cnt[t] += 1LL * (r - l + 1) * p;
        return t;
    }
    int mid = (l + r) >> 1;
    lson[t] = update(lson[x], l, mid, ll, rr, p);
    rson[t] = update(rson[x], mid + 1, r, ll, rr, p);
    cnt[t] = cnt[lson[t]] + cnt[rson[t]] + laz[t] * (r - l + 1);
    return t;
}

// Sum of [ll, rr] inside node x (covering [l, r]). val is the total of the
// ancestors' tags, which apply to every element below them.
long long que(int x, int l, int r, int ll, int rr, long long val) {
    if (l > rr || r < ll) return 0;
    if (l >= ll && r <= rr) return cnt[x] + 1LL * (r - l + 1) * val;
    val += laz[x];
    int mid = (l + r) >> 1;
    return que(lson[x], l, mid, ll, rr, val) + que(rson[x], mid + 1, r, ll, rr, val);
}

char ki[10];

int main() {
    int n, Q;
    // "== 2" rather than "!= EOF": a partial match would otherwise loop forever.
    while (scanf("%d%d", &n, &Q) == 2) {
        // Root is small, so clear it to avoid reading stale versions from a previous
        // test case. The big node arrays are fully initialised by build/Copy.
        memset(Root, 0, sizeof(Root));
        spac = 0;
        for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
        Root[0] = build(1, n);
        int tot = 0, l, r;  // tot: index of the current version
        while (Q--) {
            if (scanf("%9s", ki) != 1) break;  // bounded: ki holds 9 chars + NUL
            if (ki[0] == 'B') {
                scanf("%d", &l);
                tot = l;  // later versions are abandoned; the next 'C' overwrites Root[tot + 1]
                continue;
            }
            scanf("%d %d", &l, &r);
            if (ki[0] == 'C') {
                long long w;
                scanf("%lld", &w);
                // Build from the current version first, then advance time.
                // `Root[++tot] = update(Root[tot-1], ...)` is unsequenced (UB before
                // C++17) and reads the wrong base version under C++17 rules.
                Root[tot + 1] = update(Root[tot], 1, n, l, r, w);
                ++tot;
                continue;
            }
            if (ki[0] == 'Q') {
                printf("%lld\n", que(Root[tot], 1, n, l, r, 0LL));
                continue;
            }
            if (ki[0] == 'H') {
                int w;
                scanf("%d", &w);
                printf("%lld\n", que(Root[w], 1, n, l, r, 0LL));
            }
        }
    }
    return 0;
}
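这些题目的具体做法看一下程序就能懂了 %_%。简单说：例1 对每个前缀 $[1,i]$ 建一个版本，记录每个离散化后的值出现的次数；查询 $[a,b]$ 时同时在版本 $b$ 和版本 $a-1$ 上往下走，用两者左子树的计数差决定进入左子树还是右子树。例2 的区间修改不下传懒标记，而是把标记永久留在节点上（标记永久化），查询时沿路径把祖先的标记累加起来。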