Loading

PR#15. 二叉搜索树 做题记录

感觉有点厉害,学到了。link

如果把树换成序列,那么可能是三维偏序:

  • 时间一维

  • 位置一维

  • BST 键值一维

但是 BST 键值其实并不好分析,但是我想到了一种显然很假的维护方法:直接求出完整的 BST,然后求得每个点代表的区间,进行区间加。

所以不能直接维护,也就不能直接将其看成一维。

接下来很 Educational:把时间和 BST 键值合起来做,一个点 \(x\) 在 BST 上的到根路径上的点集为:\((1\sim x\) 中后缀最大值\()\cap (x\sim n\) 中前缀最大值\()\)

可以使用楼房重建线段树,所以就在 \(\mathcal O(n\log ^ 2n)\) 的时间解决了序列上的问题。

考虑树上问题。这里有个模型转化:二维偏序 \(\Rightarrow\) 序列线段树 \(\xrightarrow{\text{上树}}\) 树上线段树合并

所以可以树上差分后做线段树合并,时间复杂度 \(\mathcal O(n^2)\)


感觉这题还是比较 Educatianal 的,启示了 BST 的维护与楼房重建线段树的紧密联系。

我一开始的思路是序列上的三维偏序,但是却在 BST 路径维护上的想假了,这导致我浪费了很多时间。

所以,对于每一步思考,每一步的结论,都需要在草稿纸上验证一遍,而不是大脑空想,否则很容易走上一条不归路。


点击查看代码
#include <bits/stdc++.h>
namespace Initial {
	#define ll int
	#define ull unsigned ll
	#define fi first
	#define se second
	#define mkp make_pair
	#define pir pair <ll, ll>
	#define pb push_back
	#define i128 __int128
	using namespace std;
	const ll maxn = 2e5 + 10, inf = 1e9, mod = 998244353, iv = mod - mod / 2;
	ll power(ll a, ll b = mod - 2, ll p = mod) {
		ll s = 1;
		while(b) {
			if(b & 1) s = 1ll * s * a %p;
			a = 1ll * a * a %p, b >>= 1;
		} return s;
	}
	template <class T>
	const inline ll pls(const T x, const T y) { return x + y >= mod? x + y - mod : x + y; }
	template <class T>
	const inline void add(T &x, const T y) { x = x + y >= mod? x + y - mod : x + y; }
	template <class T>
	const inline void chkmax(T &x, const T y) { x = x < y? y : x; }
	template <class T>
	const inline void chkmin(T &x, const T y) { x = x < y? x : y; }
} using namespace Initial;

namespace Read {
	char buf[1 << 22], *p1, *p2;
	// #define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, (1 << 22) - 10, stdin), p1 == p2)? EOF : *p1++)
	template <class T>
	const inline void rd(T &x) {
		char ch; bool neg = 0;
		while(!isdigit(ch = getchar()))
			if(ch == '-') neg = 1;
		x = ch - '0';
		while(isdigit(ch = getchar()))
			x = (x << 1) + (x << 3) + ch - '0';
		if(neg) x = -x;
	}
} using Read::rd;

ll n, m, d[maxn][20], dep[maxn], ti[maxn], t;
vector <ll> to[maxn], ins[maxn], del[maxn];
struct Qry {ll x, w, k;} ;
vector <Qry> qry[maxn]; long long ans[maxn];

void dfs1(ll u, ll fa = 0) {
    dep[u] = dep[fa] + 1, d[u][0] = fa;
    for(ll i = 1; i < 20; i++) d[u][i] = d[d[u][i - 1]][i - 1];
    for(ll v: to[u])
        if(v ^ fa) dfs1(v, u);
}
ll lca(ll u, ll v) {
    if(dep[u] < dep[v]) swap(u, v);
    ll t = dep[u] - dep[v];
    for(ll i = 0; i < 20; i++)
        if(t & (1 << i)) u = d[u][i];
    if(u == v) return u;
    for(ll i = 19; ~i; i--)
        if(d[u][i] ^ d[v][i])
            u = d[u][i], v = d[v][i];
    return d[u][0];
}

const ll L = 2e5;
struct SGT {
    ll lc[maxn * 40], rc[maxn * 40], mn[maxn * 40], tot;
    long long sum1[maxn * 40], sum2[maxn * 40];
    ll query1(ll p, ll l, ll r, ll w) {
        if(w < mn[p] || !p) return 0;
        if(l == r) return l; ll mid = l + r >> 1;
        if(w > mn[lc[p]])
            return sum1[p] - sum1[lc[p]] + query1(lc[p], l, mid, w);
        return query1(rc[p], mid + 1, r, w);
    }
    ll query2(ll p, ll l, ll r, ll w) {
        if(w < mn[p] || !p) return 0;
        if(l == r) return l; ll mid = l + r >> 1;
        if(w > mn[rc[p]])
            return sum2[p] - sum2[rc[p]] + query2(rc[p], mid + 1, r, w);
        return query2(lc[p], l, mid, w);
    }
    void pushup(ll p, ll l, ll r) {
        mn[p] = min(mn[lc[p]], mn[rc[p]]); ll mid = l + r >> 1;
        sum1[p] = sum1[lc[p]] + query1(rc[p], mid + 1, r, mn[lc[p]]);
        sum2[p] = sum2[rc[p]] + query2(lc[p], l, mid, mn[rc[p]]);
    }
    void modify(ll &p, ll l, ll r, ll x) {
        if(!p) p = ++tot;
        if(l == r) return mn[p] = ti[x], sum1[p] = sum2[p] = x, void();
        ll mid = l + r >> 1;
        if(x <= mid) modify(lc[p], l, mid, x);
        else modify(rc[p], mid + 1, r, x);
        pushup(p, l, r);
    }
    void del(ll &p, ll l, ll r, ll x) {
        if(l == r) return p = 0, void();
        ll mid = l + r >> 1;
        if(x <= mid) del(lc[p], l, mid, x);
        else del(rc[p], mid + 1, r, x);
        if(!lc[p] && !rc[p]) p = 0;
        else pushup(p, l, r);
    }
    ll merge(ll p, ll q, ll l, ll r) {
        if(!p || !q) return p | q;
        if(l == r) return p; ll mid = l + r >> 1;
        lc[p] = merge(lc[p], lc[q], l, mid);
        rc[p] = merge(rc[p], rc[q], mid + 1, r);
        pushup(p, l, r); return p;
    }
    pir query1(ll p, ll l, ll r, ll ql, ll qr, ll w) {
        if(r < ql || qr < l || !p) return mkp(0, w);
        if(ql <= l && r <= qr) return mkp(query1(p, l, r, w), min(w, mn[p]));
        ll mid = l + r >> 1;
        pir A = query1(lc[p], l, mid, ql, qr, w);
        pir B = query1(rc[p], mid + 1, r, ql, qr, A.se);
        return mkp(A.fi + B.fi, B.se);
    }
    pir query2(ll p, ll l, ll r, ll ql, ll qr, ll w) {
        if(r < ql || qr < l || !p) return mkp(0, w);
        if(ql <= l && r <= qr) return mkp(query2(p, l, r, w), min(w, mn[p]));
        ll mid = l + r >> 1;
        pir A = query2(rc[p], mid + 1, r, ql, qr, w);
        pir B = query2(lc[p], l, mid, ql, qr, A.se);
        return mkp(A.fi + B.fi, B.se);
    }
    bool pd(ll p, ll l, ll r, ll x) {
        if(!p) return false;
        if(l == r) return true; ll mid = l + r >> 1;
        if(x <= mid) return pd(lc[p], l, mid, x);
        else return pd(rc[p], mid + 1, r, x);
    }
} tr; ll rt[maxn];

void dfs2(ll u, ll fa = 0) {
    for(ll v: to[u])
        if(v ^ fa)
            dfs2(v, u), rt[u] = tr.merge(rt[u], rt[v], 1, L);
    for(ll x: ins[u]) tr.modify(rt[u], 1, L, x);
    for(ll x: del[u]) tr.del(rt[u], 1, L, x);
    for(Qry p: qry[u]) {
        ans[p.k] += tr.query1(rt[u], 1, L, p.x, L, p.w).fi;
        ans[p.k] += tr.query2(rt[u], 1, L, 1, p.x, p.w).fi;
        if(ti[p.x] < p.w && tr.pd(rt[u], 1, L, p.x)) ans[p.k] -= p.x;
    }
}

int main() {
    rd(n), rd(m); tr.mn[0] = inf;
    for(ll i = 1; i < n; i++) {
        ll u, v; rd(u), rd(v);
        to[u].pb(v), to[v].pb(u);
    } dfs1(1);
    for(ll i = 1; i <= m; i++) {
        ll op, x, y, z; rd(op), rd(x), rd(y);
        if(op == 1) {
            rd(z); ti[z] = i;
            ins[x].pb(z), ins[y].pb(z);
            del[d[lca(x, y)][0]].pb(z);
        } else qry[x].pb((Qry) {y, i, ++t});
    } dfs2(1);
    for(ll i = 1; i <= t; i++) printf("%lld\n", ans[i]);
	return 0;
}
posted @ 2025-02-18 07:25  Sktn0089  阅读(12)  评论(0)    收藏  举报