树状数组模板

树状数组

久仰大名
终于开始学习

主要用于使用\(O(\log{n})\)的复杂度实现单点修改和区间查询

先贴一个视频助于理解

然后是放oi wiki
树状数组

然后贴出来ai写的代码模板

template<typename T> // 便于使用long long, int 等不同类型的树状数组
struct BIT
{
    int n;                  // 数据范围(最大下标)
    vector<T> tr;           // 树状数组存储空间

    BIT(int n_)
    {
        n = n_;
        tr.assign(n_ + 1, 0);   // 下标从 1 开始,多开一个位置
    }

    static int lowbit(int x)
    {
        return x & -x;
    }

    // 单点更新:在下标 idx 处增加 delta
    void add(int idx, T delta)
    {
        while (idx <= n)
        {
            tr[idx] += delta;
            idx += lowbit(idx);
        }
    }

    // 前缀查询:求 [1..idx] 的和
    T pre(int idx)
    {
        T res = 0;
        while (idx > 0)
        {
            res += tr[idx];
            idx -= lowbit(idx);
        }
        return res;
    }

    // 区间查询:求 [l..r] 的和
    T query(int l, int r)
    {
        return pre(r) - pre(l - 1);
    }
};

然后贴一道洛谷的模板题
P3374 【模板】树状数组 1

然后贴ac代码

#include <bits/stdc++.h>
using namespace std;
using i32 = int;
using i64 = long long;

template<typename T>
struct BIT
{
    int n; vector<T> tr;

    BIT(int n_)
    {
        n = n_;
        tr.assign(n + 1, 0);
    }

    int lowbit(int x)
    {
        return x & -x;
    }

    void add(int idx, T delta)
    {
        while (idx <= n)
        {
            tr[idx] += delta;
            idx += lowbit(idx);
        }
    }

    T pref(int idx)
    {
        T ans = 0;
        while (idx > 0)
        {
            ans += tr[idx];
            idx -= lowbit(idx);
        }
        return ans;
    }

    T query(int l, int r)
    {
        return pref(r) - pref(l - 1);
    }
};


int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int n, m; cin >> n >> m;
    BIT<i32> bit = BIT<i32>(n);
    for (int i = 1; i <= n; ++ i)
    {
        int x; cin >> x;
        bit.add(i, x);
    }
    // for (int i = 1; i <= n; ++ i)
    //     cerr << bit.query(i, i) << ' ';
    // cerr << '\n';
    while (m -- )
    {
        int op; cin >> op;
        if (op == 1)
        {
            int x, k; cin >> x >> k;
            bit.add(x, k);
        }
        if (op == 2)
        {
            int x, y; cin >> x >> y;
            cout << bit.query(x, y) << '\n';
        }
        // for (int i = 1; i <= n; ++ i)
        //     cerr << bit.query(i, i) << ' ';
        // cerr << '\n';
    }
    return 0;
}

现在我要假装自己学会了

posted @ 2026-05-18 20:40  RonF02  阅读(3)  评论(0)    收藏  举报