Loading

历史和线段树

历史和线段树

历史和线段树就是能在 \(O(n\log n)\) 中查询过去 \(q\) 个版本某个区间的和的总和。

形式化的说,有一个数组 \(a\) 和一个辅助数组 \(b\),每一次(广义)更新操作都会执行 \(a:[x,y] \rightarrow b:[x,y]\),查询 \(k\) 个版本后 \(b:[x,y]\) 的值(即 \(\sum\limits_{k}\sum\limits_{i = x}^y a_i\) 的值)。

矩阵法求解历史和

如何在 \(O(n\log n)\) 的时间内解决上述问题呢?

我们考虑使用线段树和矩阵。

??? note "你可能需要一个优秀的matrix实现"

```cpp 
#include<bits/stdc++.h>
using namespace std;

template<int N,int M,class T = long long>
struct matrix {
	T m[N][M];
	matrix(){memset(m,0,sizeof(m));}
	void init(){for(int i = 0;i < N;i++)    m[i][i] = 1;} //初始化
	friend bool operator != (matrix<N,  M> &x,matrix<N,M> &y) {
		for(int i = 0;i<N;i++)
			for(int j = 0;j<M;j++)
				if(x[i][j] != y[i][j])
					return true;
		return false;
	}
	friend matrix<N,M> operator +=  (matrix<N,M> x,matrix<N,M> &y) {
		for(int i = 0;i<N;i++)
			for(int j = 0;j<M;j++)
				x[i][j] += y[i][j];
		return x;
	}
	int* operator [] (const int pos)    {return m[pos];}
	void print(string s) {
		cout<<'\n';
		string t = "test for " + s + "  matrix:";
		cout<<t<<'\n';
		for(int i = 0;i<N;i++)
			for(int j = 0;j<M;j++)
				cout<<m[i][j]<<" \n"[j  == M - 1];
		cout<<'\n';
	}
};

template<int N,int M,int R,class T =    long long>
matrix<N,R,T> operator * (matrix<N,M,T>     a,matrix<M,R,T> b) {
	matrix<N,R,T> c;
	for(int i = 0;i<N;i++)
		for(int j = 0;j<M;j++)
			for(int k = 0;k<R;k++)
				c[i][k] = c[i][k] + a[i]    [j] * b[j][k];
	return c;
}

template<int N,int M,class T = long long>
matrix<N,M,T> operator + (matrix<N,M,T>     a,matrix<N,M,T> b) {
	for(int i = 0;i<N;i++)
		for(int j = 0;j<M;j++)
			a[i][j] += b[i][j];
	return a;
}
template<int N,class T = long long>
matrix<N,N,T> qpow(matrix<N,N,T> x,int  k) {
	matrix<N,N,T> re;
	re.init();
	while(k){
		if(k & 1) re = re * x;
		x = x * x;
		k >>= 1;
	}
	return re;
}
```

???+note "LOJ193"
题目描述
这是一道模板题。
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
1.区间加一个数;
2. 查询区间的历史和;
历史和定义为数列 \(h_i\) 的区间和:初始 \(h_i=a_i\),在每次操作(修改或查询,具体可参考样例解释)完成后,对所有 \(h_i \leftarrow h_i+a_i\)

以 LOJ193 为例,我们要求最朴素的历史和。

我们可以用线段树维护矩阵,其中矩阵为:

\[\begin{bmatrix} his \\ sum \\ len \end{bmatrix} \]

其中 \(his\) 为历史和,\(sum\) 为区间和,\(len\) 为区间长度。

其实就是用矩阵打包线段树上要维护的所有变量。

对于叶子节点,\(len = 1\),\(sum = a_i\)\(his = 0\);对于非叶子节点,\(tag = \begin{bmatrix} 1 & 0 & 0\\ 0 & 1 & 0\\ 0 & 0 & 1\\ \end{bmatrix}\)

然后是对线段树进行区间矩阵乘操作:

节点的合并(例节点 \(a + b \rightarrow c\)):

\[\begin{bmatrix} his_a \\ sum_a \\ len_a \end{bmatrix} + \begin{bmatrix} his_b \\ sum_b \\ len_b \end{bmatrix} = \begin{bmatrix} his_c \\ sum_c \\ len_c \end{bmatrix} \]

区间加 \(d\) 操作为:

\[\begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & d \\ 0 & 0 & 1 \\ \end{bmatrix} \times \begin{bmatrix} his \\ sum \\ len \end{bmatrix} = \begin{bmatrix} his \\ sum + d \times len \\ len \end{bmatrix} \]

区间历史和更新操作为:

\[\begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \\ \end{bmatrix} \times \begin{bmatrix} his \\ sum \\ len \end{bmatrix} = \begin{bmatrix} his + sum \\ sum \\ len \end{bmatrix} \]

我们将矩阵按线段树的方式下放到指定区间即可。

我们每次进行区间加 \(d\) 时,对全局进行历史和更新操作

最后查询区间矩阵的历史和,只需按线段树的方式求区间矩阵的 \(\sum his\) 即可。

这样你就成功完成了此题!

??? note "解题代码?"

```cpp       
#include <bits/stdc++.h>
using namespace std;
#define int long long
int rd() {
    int x = 0, w = 1;
    char ch = 0;

    while (ch < '0' || ch > '9') {
        if (ch == '-')
            w = -1;

        ch = getchar();
    }

    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }

    return x * w;
}
void wt(int x) {
    static int sta[35];
    int f = 1;

    if (x < 0)
        f = -1, x *= f;

    int top = 0;

    do {
        sta[top++] = x % 10, x /= 10;
    } while (x);

    if (f == -1)
        putchar('-');

    while (top)
        putchar(sta[--top] + 48);
}
template<int N, int M, class T = long   long>
struct matrix {
    int m[N][M];
    matrix() {
        memset(m, 0, sizeof(m));
    }
    void init() {
        for (int i = 0; i < N; i++)
            m[i][i] = 1;
    }
    friend bool operator != (matrix<N,  M> x, matrix<N, M> y) {
        for (int i = 0; i < N; i++)
            for (int j = 0; j < M; j++)
                if (x[i][j] != y[i][j])
                    return true;

        return false;
    }
    int *operator [](const int pos) {
        return m[pos];
    }
    void print(string s) {
        cout << '\n';
        string t = "test for " + s + "  matrix:";
        cout << t << '\n';

        for (int i = 0; i < N; i++)
            for (int j = 0; j < M; j++)
                cout << m[i][j] << " \n"    [j == M - 1];

        cout << '\n';
    }
};
template<int N, int M, int R, class T =     long long>
matrix<N, R, T> operator * (matrix<N, M,    T> a, matrix<M, R, T> b) {
    matrix<N, R, T> c;

    for (int i = 0; i < N; i++)
        for (int j = 0; j < M; j++)
            for (int k = 0; k < R; k++)
                c[i][k] = c[i][k] + a[i]    [j] * b[j][k];

    return c;
}
template<int N, int M, class T = long   long>
matrix<N, M, T> operator + (matrix<N, M,    T> a, matrix<N, M, T> b) {
    for (int i = 0; i < N; i++)
        for (int j = 0; j < M; j++)
            a[i][j] += b[i][j];

    return a;
}

const int N = 1e5 + 5;
int n, m, a[N];
namespace sgt {
matrix<3, 1> h[N << 2];
matrix<3, 3> tag[N << 2];
#define ls (p << 1)
#define rs (ls | 1)
#define mid ((pl + pr) >> 1)
void push_up(int p) {
    h[p] = h[ls] + h[rs];
}

void addtag(int p, matrix<3, 3> c) {
    h[p] = c * h[p] ;
    tag[p] = c * tag[p];
}

void push_down(int p) {
    matrix<3, 3> c;
    c.init();

    if (tag[p] != c) {
        addtag(ls, tag[p]);
        addtag(rs, tag[p]);
        tag[p] = c;
    }
}

void build(int p, int pl, int pr) {
    matrix<3, 3> c;
    c.init();
    tag[p] = c;

    if (pl == pr) {
        h[p][0][0] = h[p][1][0] = a[pl];
        h[p][2][0] = 1;
        return;
    }

    build(ls, pl, mid);
    build(rs, mid + 1, pr);
    push_up(p);
}

void update(int p, int pl, int pr, int  l, int r, matrix<3, 3> v) {
    if (l <= pl && pr <= r) {
        addtag(p, v);
        return;
    }

    push_down(p);

    if (l <= mid)
        update(ls, pl, mid, l, r, v);

    if (r > mid)
        update(rs, mid + 1, pr, l, r, v);

    push_up(p);
}

int query(int p, int pl, int pr, int l,     int r) {
    if (l <= pl && pr <= r)
        return h[p][0][0];

    push_down(p);
    int ans = 0;

    if (l <= mid)
        ans += query(ls, pl, mid, l, r);

    if (r > mid)
        ans += query(rs, mid + 1, pr, l,    r);

    return ans;
}

}

signed main() {
    n = rd(), m = rd();

    for (int i = 1; i <= n; i++)
        a[i] = rd();

    sgt::build(1, 1, n);
    auto upd = [&]() -> void {
        int l = rd(), r = rd(), x = rd();
        matrix<3, 3> c;
        c.init();
        c[1][2] = x;
        sgt::update(1, 1, n, l, r, c);
    };
    auto qry = [&]() -> void {
        int l = rd(), r = rd();
        wt(sgt::query(1, 1, n, l, r));
        putchar('\n');
    };

    while (m--) {
        int opt = rd();

        switch (opt) {
        case 1:
            upd();
            break;

        case 2:
            qry();
            break;

        default:
            puts("Error");
            exit(0);
            break;
        }

        matrix<3, 3> v;
        v.init();
        v[0][1] = 1;
        sgt::update(1, 1, n, 1, n, v);
    }

    return 0;
}
```

通过记录:accept?

进一步优化

我们的矩阵乘法要维护两个 \(3 \times 3\) 矩阵相乘的结果,这带来的结果是常数来到了惊人的 \(27\),然而这是无法接受的!

这时聪明的奶龙就发现了,矩阵的好多地方是不变的

我们可以用下面的代码来探究到底哪些矩阵元素永远不会变:
??? "探究随机矩阵乘所固定的元素"

```cpp   
#include<bits/stdc++.h>
using namespace std;

int rd() {
	int x = 0, w = 1;
	char ch = 0;
	while (ch < '0' || ch > '9') {
		if (ch == '-') w = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * w;
}
void wt(int x) {
	static int sta[35];
	int f = 1;
	if(x < 0) f = -1,x *= f;
	int top = 0;
	do {
		sta[top++] = x % 10, x /= 10;
	} while (x);
	if(f == -1) putchar('-');
	while (top) putchar(sta[--top] + 48);
}

template<int N,int M,class T = long long>
struct matrix {
	int m[N][M];
	matrix(){memset(m,0,sizeof(m));}
	void init(){for(int i = 0;i < N;i++)    m[i][i] = 1;}
	friend bool operator != (matrix<N,M>    x,matrix<N,M> y) {
		for(int i = 0;i<N;i++)
			for(int j = 0;j<M;j++)
				if(x[i][j] != y[i][j])
					return true;
		return false;
	}
	int* operator [] (const int pos)    {return m[pos];}
	void print(string s) {
		cout<<'\n';
		string t = "test for " + s + "  matrix:";
		cout<<t<<'\n';
		for(int i = 0;i<N;i++)
			for(int j = 0;j<M;j++)
				cout<<m[i][j]<<" \n"[j  == M - 1];
		cout<<'\n';
	}
};
template<int N,int M,int R,class T =    long long>
matrix<N,R,T> operator * (matrix<N,M,T>     a,matrix<M,R,T> b) {
	matrix<N,R,T> c;
	for(int i = 0;i<N;i++)
		for(int j = 0;j<M;j++)
			for(int k = 0;k<R;k++)
				c[i][k] = c[i][k] + a[i]    [j] * b[j][k];
	return c;
}
template<int N,int M,class T = long long>
matrix<N,M,T> operator + (matrix<N,M,T>     a,matrix<N,M,T> b) {
	for(int i = 0;i<N;i++)
		for(int j = 0;j<M;j++)
			a[i][j] += b[i][j];
	return a;
}

template<int N,class T = long long>
matrix<N,N,T> qpow(matrix<N,N,T> x,int  k) {
	matrix<N,N,T> re;
	re.init(); 
	while(k) {
		if(k & 1) re = re * x;
		x = x * x;
		k >>= 1; 
	}
	return re;
}
matrix<3,3> re,b;
signed main() {
    re.init();
    while(1) {
        int c = rd();
        if(c == 0) return 0;
        if(c == 1) {
            b.init();
            int x = rd();
            b[1][2] = x;
            re = b * re;
            re.print("result:");
        }else if(c == 2) {
            b.init();
            b[0][1] = 1;
            re = b * re;
            re.print("result:");
        }
    }
	return 0;
}
```

我们会惊讶的发现,实际上矩阵中只有四个位置是在变化的:

\[\begin{bmatrix} 1 & a & b \\ 0 & c & d \\ 0 & 0 & 1 \\ \end{bmatrix} \]

那么,我们可以通过手摸矩阵来达到 \(3 \sim 4\) 的复杂度常数!

??? note "通过记录!"
```cpp
#include<bits/stdc++.h>
using namespace std;

int rd() {
	int x = 0, w = 1;
	char ch = 0;
	while (ch < '0' || ch > '9') {
		if (ch == '-') w = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * w;
}
void wt(int x) {
	static int sta[35];
	int f = 1;
	if(x < 0) f = -1,x *= f;
	int top = 0;
	do {
		sta[top++] = x % 10, x /= 10;
	} while (x);
	if(f == -1) putchar('-');
	while (top) putchar(sta[--top] + 48);
}

struct tag{
    int x[7];
    void init() {
        x[1] = x[4] = x[6] = 1;
        x[2] = x[3] = x[5] = 0;
    }
    int& operator [](const int pos)     {return x[pos];}
    friend tag operator * (tag& A,tag&  B) {
        tag c;c.init();
        c[2] = A[2] + B[2];
        c[3] = B[3] + A[2] * B[5] + A[3];
        c[5] = B[5] + A[5];
        return c;
    }
    friend bool operator != (tag A,tag  B) {
        for(int i = 0;i<7;i++)
            if(A[i] != B[i])
                return true;
        return false;
    }
    void print(string s) {
        cout<<"test for "<<s<<"     matrix\n";
        cout<<x[1]<<' '<<x[2]<<' '<<x[3]    <<'\n';
        cout<<0<<' '<<x[4]<<' '<<x[5]   <<'\n';
        cout<<0<<' '<<0<<' '<<x[6]  <<'\n'; 
   }
};

struct vet{
    int y[4];
    void init() {y[1] = y[2] = y[3] = 0;}
    int& operator [](const int pos)     {return y[pos];}
    friend vet operator + (vet a,vet b) {
        vet c;c.init();
        c[1] = a[1] + b[1];
        c[2] = a[2] + b[2];
        c[3] = a[3] + b[3];
        return c;
    }
    void print(string s) {
        cout<<'\n';
        cout<<"test for "<<s<<"     vector\n";
        cout<<y[1]<<'\n';
        cout<<y[2]<<'\n';
        cout<<y[3]<<'\n';
        cout<<'\n';
    }
};

vet operator * (tag A,vet B) {
    vet c;c.init();
    c[1] = B[1] + B[2] * A[2] +B[3] * A [3];
    c[2] = B[2] + A[5] * B[3];
    c[3] = B[3];
    return c;
}

const int N = 1e5+5;
int n,m,a[N];

namespace sgt{
#define ls (p << 1)
#define rs (ls | 1)
#define mid ((pl + pr) >> 1)
tag T[N<<2];
vet t[N<<2];
void push_up(int p) {
    t[p] = t[ls] + t[rs];
}

void addtag(int p,tag x) {
    T[p] = x * T[p];
    t[p] = x * t[p];
}

void push_down(int p) {
    tag c;c.init();
    if(T[p] != c) {
        addtag(ls,T[p]);
        addtag(rs,T[p]);
        T[p].init();
    }
}

void build(int p,int pl,int pr) {
    T[p].init();
    if(pl == pr) {
        t[p][2] = t[p][1] = a[pl];
        t[p][3] = 1;
        return;
    }
    build(ls,pl,mid);
    build(rs,mid+1,pr);
    push_up(p);
}

void update(int p,int pl,int pr,int l,  int r,tag x) {
    if(l <= pl && pr <= r) {
        addtag(p,x);
        // t[p].print("upd");
        // T[p].print("upd");
        return;
    }
    push_down(p);
    if(l <= mid) update(ls,pl,mid,l,r,x);
    if(r > mid) update(rs,mid+1,pr,l,r, x);
    push_up(p);
}

int query(int p,int pl,int pr,int l,int     r) {
    if(l <= pl && pr <= r) return t[p]  [1];
    push_down(p);
    if(r <= mid) return query(ls,pl,mid,    l,r);
    else if(l > mid) return query(rs,mid    +1,pr,l,r);
    else return query(ls,pl,mid,l,r) +  query(rs,mid+1,pr,l,r);
}

}


signed main() {
    n = rd(),m = rd();
    for(int i = 1;i<=n;i++) a[i] = rd();
    sgt::build(1,1,n);
    while(m--) {
        int opt = rd();
        if(opt == 1) {
            int l = rd(),r = rd(),x = rd    ();
            tag c;c.init();
            c[5] = x;
            sgt::update(1,1,n,l,r,c);
        }else {
            int l = rd(),r = rd();
            wt(sgt::query(1,1,n,l,r));
            putchar('\n');
        }
        tag c;c.init();
        c[2] = 1;
        sgt::update(1,1,n,1,n,c);
    }
	return 0;
}
```

通过记录:accept!

可以看到区别还是很大的!

推理法求历史和

待后人补充!

值得注意的事情

矩阵所维护的元素所执行的操作无非加减乘除,这是矩阵的作为线性代数的性质导致的。

也就是说,历史和线段树只能用来维护具有线性关系的元素!

所有对于一类历史和线段树问题,思路都是尝试转换成一系列线性关系的操作。

习题

codeforces 1824D: LuoTianyi and the Function ~~~ 题解

NOIP2022 比赛

posted @ 2025-04-05 21:52  MingJunYi  阅读(259)  评论(0)    收藏  举报