JZOJ 7036. 2021.03.30【2021省赛模拟】凌乱平衡树(平衡树单旋+权值线段树)
JZOJ 7036. 2021.03.30【2021省赛模拟】凌乱平衡树
题目大意
- 给出两棵Treap,大小分别为 n , m n,m n,m,每个点的 p r i o r i t y priority priority值为子树大小(因此满足大根堆性质), Q Q Q次修改(修改是永久的),每次单旋一个节点,求修改前和每次修改后后两树合并之后的所有节点深度之和。合并按照Treap的合并方式,左树根为 x x x,右树根为 y y y时,当 s i z e x ≥ s i z e y size_x\ge size_y sizex≥sizey时以 x x x为根,否则反之。
- 1 ≤ n , m , Q ≤ 2 ∗ 1 0 5 1\le n,m,Q\le2*10^5 1≤n,m,Q≤2∗105
题解
- 考虑合并的过程,记录当前深度 d p dp dp,左树根每次向右走,就加上左儿子 F + G ∗ d p F+G*dp F+G∗dp,含义是所有点到左子树根的深度加上到实际的根深度差值。右边同理。这样需要在每次单旋后重新计算每个子树的大小 G G G及以该子树根为根的深度和 F F F。 G G G可以在常数复杂度内维护,但 F F F不行。
- 换一种思路,记录总的深度和 s u m sum sum,每次求出合并后增加的差值。这样合并的过程中,左树根每次向右走,则加上右树根的 G G G,含义是它子树内所有点的深度都会被增加 1 1 1。右边同理。
- 而合并时左树根始终向右,右树根始终向左,其它的节点是不会经过的,且与它相关的值也不会调用到,所以可以把左根向右和右根向左两条链(以下称为链)单独看,设链上 G G G序列左边依次为 A A A,右边为 B B B。 A i A_i Ai对答案的贡献次数为 ( A i , A i − 1 ] (A_i,A_{i-1}] (Ai,Ai−1]中 B B B的个数, B i B_i Bi对答案的贡献次数为 [ B i , B i − 1 ) [B_i,B_{i-1}) [Bi,Bi−1)中 A A A的个数,注意这里区间的开闭情况。
- 那么可以用权值线段树维护,把初始的 A A A和 B B B都存进同一棵权值线段树中,在单旋时进行修改。
- 只有两种情况需要修改:
- 1、单旋的节点 x x x和 x x x的父亲都在链中;
- 2、单旋的节点 x x x不在链中, x x x的父亲在链中。
- 修改时因为 G G G值会改变,所以需要先删除该点及其贡献,修改完 G G G后再加入回来。修改的贡献不仅有它自己的贡献,还有 A A A和 B B B中它们前驱的贡献。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 200010
#define ll long long
struct {
int p[2];
}f[N * 4];
int ns;
ll ans;
void is(int v, int l, int r, int x, int o, int c) {
if(l == r) {
f[v].p[o] += c;
}
else {
int mid = (l + r) / 2;
if(x <= mid) is(v * 2, l, mid, x, o, c); else is(v * 2 + 1, mid + 1, r, x, o, c);
f[v].p[o] = f[v * 2].p[o] + f[v * 2 + 1].p[o];
}
}
int get(int v, int l, int r, int x, int y, int o) {
if(x > y) return 0;
if(l == x && r == y) return f[v].p[o];
int mid = (l + r) / 2;
if(y <= mid) return get(v * 2, l, mid, x, y, o);
if(x > mid) return get(v * 2 + 1, mid + 1, r, x, y, o);
return get(v * 2, l, mid, x, mid, o) + get(v * 2 + 1, mid + 1, r, mid + 1, y, o);
}
int find(int v, int l, int r, int x, int y, int k, int o) {
if(f[v].p[o] < k || x > y) return -1;
if(l == r) return l;
int mid = (l + r) / 2;
if(y <= mid) return find(v * 2, l, mid, x, y, k, o);
if(x > mid) return find(v * 2 + 1, mid + 1, r, x, y, k, o);
int s = get(v * 2, l, mid, x, mid, o);
if(s >= k) return find(v * 2, l, mid, x, mid, k, o);
return find(v * 2 + 1, mid + 1, r, mid + 1, y, k - s, o);
}
int find0(int v, int l, int r, int x, int y, int k, int o) {
if(f[v].p[o] < k || x > y) return -1;
if(l == r) return l;
int mid = (l + r) / 2;
if(y <= mid) return find0(v * 2, l, mid, x, y, k, o);
if(x > mid) return find0(v * 2 + 1, mid + 1, r, x, y, k, o);
int s = get(v * 2 + 1, mid + 1, r, mid + 1, y, o);
if(s >= k) return find0(v * 2 + 1, mid + 1, r, mid + 1, y, k, o);
return find0(v * 2, l, mid, x, mid, k - s, o);
}
ll count(int x, int o) {
if(!o) {
int t = find(1, 1, ns, x, ns, 2, 0);
if(t == -1) t = ns;
return (ll)get(1, 1, ns, x + 1, t, 1) * x;
}
else {
int t = find(1, 1, ns, x, ns, 2, 1);
if(t == -1) t = ns + 1;
return (ll)get(1, 1, ns, x, t - 1, 0) * x;
}
}
int fr(int x, int o) {
if(!o) {
int t = find0(1, 1, ns, 1, x, 1, 1);
return t == -1 ? 0 : t;
}
else {
int t = find0(1, 1, ns, 1, x - 1, 1, 0);
return t == -1 ? 0 : t;
}
}
struct {
int s, rt, p[N];
ll F[N], si[N], sum;
struct {
int s[2], fa, p;
}f[N];
void ins(int r, int l, int i) {
f[i].s[0] = l, f[i].s[1] = r;
f[l].fa = f[r].fa = i;
f[l].p = 0, f[r].p = 1;
}
void ro(int x, int o) {
int y = f[x].fa, z = f[y].fa, py = f[x].p, pz = f[y].p;
f[z].s[pz] = x, f[x].fa = z, f[x].p = pz;
f[y].s[py] = f[x].s[py ^ 1], f[f[x].s[py ^ 1]].fa = y, f[f[x].s[py ^ 1]].p = py;
f[x].s[py ^ 1] = y, f[y].fa = x, f[y].p = py ^ 1;
if(rt == y) rt = x;
int tp;
if(p[y] && p[x]) {
ans -= count(si[x], o) + count(si[y], o);
ans -= fr(si[y], o) + fr(si[x], o);
tp = find0(1, 1, ns, 1, si[x], 2, o);
if(tp > 0) ans -= count(tp, o);
is(1, 1, ns, si[y], o, -1);
is(1, 1, ns, si[x], o, -1);
}
else if(p[y] && !p[x]) {
ans -= count(si[y], o);
ans -= fr(si[y], o);
tp = find0(1, 1, ns, 1, si[y], 2, o);
if(tp > 0) ans -= count(tp, o);
is(1, 1, ns, si[y], o, -1);
}
si[y] = si[f[y].s[0]] + si[f[y].s[1]] + 1;
si[x] = si[f[x].s[0]] + si[f[x].s[1]] + 1;
sum += si[f[y].s[py ^ 1]] - si[f[x].s[py]];
if(p[y] && p[x]) {
is(1, 1, ns, si[x], o, 1);
ans += fr(si[x], o);
ans += count(si[x], o);
if(tp > 0) ans += count(tp, o);
p[y] = 0;
}
else if(p[y] && !p[x]) {
is(1, 1, ns, si[x], o, 1);
is(1, 1, ns, si[y], o, 1);
ans += fr(si[y], o) + fr(si[x], o);
ans += count(si[y], o) + count(si[x], o);
if(tp > 0) ans += count(tp, o);
p[x] = 1;
}
}
int find() {
for(int i = 1; i <= s; i++) if(f[i].fa == 0) return i;
}
void dfs(int k) {
F[k] = 1, si[k] = 1;
if(f[k].s[0]) dfs(f[k].s[0]), si[k] += si[f[k].s[0]], F[k] += F[f[k].s[0]] + si[f[k].s[0]];
if(f[k].s[1]) dfs(f[k].s[1]), si[k] += si[f[k].s[1]], F[k] += F[f[k].s[1]] + si[f[k].s[1]];
}
}a, b;
void solve() {
int x = a.rt;
while(x) is(1, 1, ns, a.si[x], 0, 1), a.p[x] = 1, x = a.f[x].s[1];
x = b.rt;
while(x) is(1, 1, ns, b.si[x], 1, 1), b.p[x] = 1, x = b.f[x].s[0];
ans = 0;
x = a.rt;
while(x) ans += count(a.si[x], 0) ,x = a.f[x].s[1];
x = b.rt;
while(x) ans += count(b.si[x], 1), x = b.f[x].s[0];
printf("%lld\n", ans + a.sum + b.sum);
}
int read() {
int s = 0;
char x = getchar();
while(x < '0' || x > '9') x = getchar();
while(x >= '0' && x <= '9') s = s * 10 + x - 48, x = getchar();
return s;
}
int main() {
int Q, i;
scanf("%d%d", &a.s, &b.s);
for(i = 1; i <= a.s; i++) {
a.ins(read(), read(), i);
}
for(i = 1; i <= b.s; i++) {
b.ins(read(), read(), i);
}
ns = max(a.s, b.s) + 1;
a.rt = a.find(), b.rt = b.find();
a.dfs(a.rt), b.dfs(b.rt);
a.sum = a.F[a.rt], b.sum = b.F[b.rt];
scanf("%d", &Q);
solve();
while(Q--) {
if(read() == 1) a.ro(read(), 0); else b.ro(read(), 1);
printf("%lld\n", ans + a.sum + b.sum);
}
return 0;
}
自我小结
- 细节比较多,各条语句中的顺序很重要,需要理清楚。
哈哈哈哈哈哈哈哈哈哈

浙公网安备 33010602011771号