codeforces常规线段树专项练习
以下是常规线段树模板,采用0-based、左闭右开写法,支持单点修改(含赋值set和增加add)、区间查询find、查找第一个满足条件元素findFirst、查找最后一个满足条件元素findLast。
template<class Val>
struct SegTree {
    int n = 0;
    std::vector<Val> val;
    void init(int _n, Val v = Val()) {
        std::vector<Val> tmp(_n, v);
        init(tmp.data(), tmp.size());
    }
    template<class T>
    void init(T *v, int _n) {
        n = _n;
        val.assign(4 << std::__lg(n), Val());
        std::function<void(int,int,int)> build = [&](int x, int l, int r) {
            if (l + 1 == r) {
                val[x] = Val(v[l]);
                return;
            }
            int m = (l + r) / 2;
            build(2*x+1, l, m);
            build(2*x+2, m, r);
            pushup(x);
        };
        build(0, 0, n);
    }
    void pushup(int x) {
        val[x] = val[2*x+1] + val[2*x+2];
    }
    void set(int x, int l, int r, int p, const Val &v) {
        if (l + 1 == r) {
            val[x] = Val(v);
            return;
        }
        int m = (l + r) / 2;
        if (p < m) {
            set(2*x+1, l, m, p, v);
        } else {
            set(2*x+2, m, r, p, v);
        }
        pushup(x);
    }
    void set(int p, const Val &v) {
        set(0, 0, n, p, v);
    }
    void add(int x, int l, int r, int p, const Val &v) {
        if (l + 1 == r) {
            val[x] += Val(v);
            return;
        }
        int m = (l + r) / 2;
        if (p < m) {
            add(2*x+1, l, m, p, v);
        } else {
            add(2*x+2, m, r, p, v);
        }
        pushup(x);
    }
    void add(int p, const Val &v) {
        add(0, 0, n, p, v);
    }
    Val find(int x, int l, int r, int L, int R) {
        if (R <= l || r <= L) {
            return Val();
        }
        if (L <= l && r <= R) {
            return val[x];
        }
        int m = (l + r) / 2;
        return find(2*x+1, l, m, L, R) + find(2*x+2, m, r, L, R);
    }
    Val find(int L, int R) {
        return find(0, 0, n, L, R);
    }
    template<class F>
    std::pair<int,Val> findFirst(int x, int l, int r, int L, int R, F pred) {
        if (R <= l || r <= L || !pred(val[x])) {
            return {-1, Val()};
        }
        if (l + 1 == r) {
            return {l, val[x]};
        }
        int m = (l + r) / 2;
        std::pair<int,Val> res = findFirst(2*x+1, l, m, L, R, pred);
        if (res.first == -1) {
            res = findFirst(2*x+2, m, r, L, R, pred);
        }
        return res;
    }
    template<class F>
    std::pair<int,Val> findFirst(int L, int R, F pred) {
        return findFirst(0, 0, n, L, R, pred);
    }
    template<class F>
    std::pair<int,Val> findLast(int x, int l, int r, int L, int R, F pred) {
        if (R <= l || r <= L || !pred(val[x])) {
            return {-1, Val()};
        }
        if (l + 1 == r) {
            return {l, val[x]};
        }
        int m = (l + r) / 2;
        std::pair<int,Val> res = findLast(2*x+2, m, r, L, R, pred);
        if (res.first == -1) {
            res = findLast(2*x+1, l, m, L, R, pred);
        }
        return res;
    }
    template<class F>
    std::pair<int,Val> findLast(int L, int R, F pred) {
        return findLast(0, 0, n, L, R, pred);
    }
};
1A:单点修改,区间求和。
#include <bits/stdc++.h>
using i64 = long long;
// segtree模板。。。
struct Val {
    i64 sum;
    Val():sum(0) {}
    Val(i64 v):sum(v) {}
    friend Val operator+(const Val &a, const Val &b) {
        Val ans;
        ans.sum = a.sum + b.sum;
        return ans;
    }
};
void solve() {
    int n, m;
    std::cin >> n >> m;
    std::vector<int> a(n);
    for (int i = 0; i < n; i++) {
        std::cin >> a[i];
    }
    SegTree<Val> st;
    st.init(a.data(), a.size());
    for (int i = 0; i < m; i++) {
        int t, x, y;
        std::cin >> t >> x >> y;
        if (t == 1) {
            st.set(x, y);
        } else {
            std::cout << st.find(x, y).sum << '\n';
        }
    }
}
int main() {
    std::cin.tie(0)->sync_with_stdio(0);
    int t = 1;
    while (t--) solve();
    return 0;
}
1B:单点修改,求区间最小。
#include <bits/stdc++.h>
using i64 = long long;
// segtree模板。。。
struct Val {
    i64 min;
    Val():min(1E18) {}
    Val(i64 v):min(v) {}
    friend Val operator+(const Val &a, const Val &b) {
        Val ans;
        ans.min = std::min(a.min, b.min);
        return ans;
    }
};
void solve() {
    int n, m;
    std::cin >> n >> m;
    std::vector<int> a(n);
    for (int i = 0; i < n; i++) {
        std::cin >> a[i];
    }
    SegTree<Val> st;
    st.init(a.data(), a.size());
    for (int i = 0; i < m; i++) {
        int t, x, y;
        std::cin >> t >> x >> y;
        if (t == 1) {
            st.set(x, y);
        } else {
            std::cout << st.find(x, y).min << '\n';
        }
    }
}
int main() {
    std::cin.tie(0)->sync_with_stdio(0);
    int t = 1;
    while (t--) solve();
    return 0;
}
1C:单点修改,求区间最小以及次数
#include <bits/stdc++.h>
using i64 = long long;
// segtree板板。。。
struct Val {
    int min, cnt;
    Val():min(2E9),cnt(0) {}
    Val(int v):min(v),cnt(1) {}
    friend Val operator+(const Val &a, const Val &b) {
        Val ans;
        if (a.min < b.min) {
            return a;
        }
        if (b.min < a.min) {
            return b;
        }
        ans.min = a.min;
        ans.cnt = a.cnt + b.cnt;
        return ans;
    }
};
void solve() {
    int n, m;
    std::cin >> n >> m;
    std::vector<int> a(n);    
    for (int i = 0; i < n; i++) {
        std::cin >> a[i];
    }
    SegTree<Val> st;
    st.init(a.data(), a.size());
    for (int i = 0; i < m; i++) {
        int z, x, y;
        std::cin >> z >> x >> y;
        if (z == 1) {
            st.set(x, y);
        } else if (z == 2) {
            Val res = st.find(x, y);
            std::cout << res.min << " " << res.cnt << "\n";
        }
    }
    int ans = 0;
    for (int i = 0; i < n; i++) {
        int cnt = st.find(i, i+1).cnt;
        if (cnt > 0) {
            ans++;
        }
    }
}
int main() {
    std::cin.tie(0)->sync_with_stdio(0);
    int t = 1;
    while (t--) solve();
    return 0;
}
2A:单点修改,求最大子段和。
#include <bits/stdc++.h>
using i64 = long long;
// segtree模板。。。
struct Val {
    i64 pre, sum, suf, max;
    Val():pre(0), sum(0), suf(0), max(0) {}
    Val(i64 v):pre(v), sum(v), suf(v), max(std::max(0LL, v)) {}
    friend Val operator+(const Val &a, const Val &b) {
        Val ans;
        ans.pre = std::max(a.pre, a.sum + b.pre);
        ans.sum = a.sum + b.sum;
        ans.suf = std::max(b.suf, b.sum + a.suf);
        ans.max = std::max({a.max, b.max, a.suf + b.pre});
        return ans;
    }
};
void solve() {
    int n, m;
    std::cin >> n >> m;
    std::vector<i64> a(n);
    for (int i = 0; i < n; i++) {
        std::cin >> a[i];
    }
    SegTree<Val> st;
    st.init(a.data(), a.size());
    std::cout << st.find(0, n).max << "\n";
    for (int i = 0; i < m; i++) {
        int x, y;
        std::cin >> x >> y;
        st.set(x, y);
        std::cout << st.find(0, n).max << "\n";
    }
}
int main() {
    std::cin.tie(0)->sync_with_stdio(0);
    int t = 1;
    while (t--) solve();
    return 0;
}
2B:给定01串,单点修改,求第k个1的下标。
#include <bits/stdc++.h>
using i64 = long long;
// segtree模板。。。
struct Val {
    int sum;
    Val():sum(0) {}
    Val(int v):sum(v) {}
    friend Val operator+(const Val &a, const Val &b) {
        Val ans;
        ans.sum = a.sum + b.sum;
        return ans;
    }
};
void solve() {
    int n, m;
    std::cin >> n >> m;
    std::vector<int> a(n);
    for (int i = 0; i < n; i++) {
        std::cin >> a[i];
    }
    SegTree<Val> st;
    st.init(a.data(), a.size());
    for (int i = 0; i < m; i++) {
        int x, y;
        std::cin >> x >> y;
        if (x == 1) {
            int v = st.find(y, y + 1).sum;
            st.set(y, 1 - v);
        } else if (x == 2) {
            int lo = 0, hi = n - 1;
            while (lo < hi) {
                int mid = lo + (hi - lo) / 2;
                if (st.find(0, mid + 1).sum > y) {
                    hi = mid;
                } else {
                    lo = mid + 1;
                }
            }
            std::cout << lo << '\n';
        }
    }
}
int main() {
    std::cin.tie(0)->sync_with_stdio(0);
    int t = 1;
    while (t--) solve();
    return 0;
}
2C:单点修改,求第1个大于等于指定值的数的下标。
#include <bits/stdc++.h>
using i64 = long long;
// segtree模板。。。
struct Val {
    int max;
    Val():max(-1E9) {}
    Val(int v):max(v) {}
    friend Val operator+(const Val &a, const Val &b) {
        Val ans;
        ans.max = std::max(a.max, b.max);
        return ans;
    }
};
void solve() {
    int n, m;
    std::cin >> n >> m;
    std::vector<int> a(n);
    for (int i = 0; i < n; i++) {
        std::cin >> a[i];
    }
    SegTree<Val> st;
    st.init(a.data(), a.size());
    for (int i = 0; i < m; i++) {
        int op;
        std::cin >> op;
        if (op == 1) {
            int u, v;
            std::cin >> u >> v;
            st.set(u, v);
        } else if (op == 2) {
            int x;
            std::cin >> x;
            std::cout << st.findFirst(0, n, [&](const Val &v) { return v.max >= x; }).first << '\n';
        }
    }
}
int main() {
    std::cin.tie(0)->sync_with_stdio(0);
    int t = 1;
    while (t--) solve();
    return 0;
}
2D:单点修改,求指定区间内第1个大于等于指定值的数的下标。
#include <bits/stdc++.h>
using i64 = long long;
// segtree模板。。。
struct Val {
    int max;
    Val():max(-1E9) {}
    Val(int v):max(v) {}
    friend Val operator+(const Val &a, const Val &b) {
        Val ans;
        ans.max = std::max(a.max, b.max);
        return ans;
    }
};
void solve() {
    int n, m;
    std::cin >> n >> m;
    std::vector<int> a(n);
    for (int i = 0; i < n; i++) {
        std::cin >> a[i];
    }
    SegTree<Val> st;
    st.init(a.data(), a.size());
    for (int i = 0; i < m; i++) {
        int op, x, y;
        std::cin >> op >> x >> y;
        if (op == 1) {
            st.set(x, y);
        } else {
            std::cout << st.findFirst(y, n, [&](const Val &v) { return v.max >= x; }).first << '\n';
        }
    }
}
int main() {
    std::cin.tie(0)->sync_with_stdio(0);
    int t = 1;
    while (t--) solve();
    return 0;
}
                    
                
                
            
        
浙公网安备 33010602011771号