另类的懒标记形式——矩阵乘法

“数学的本质在于它的自由”


2025.9.7 upd: cpu的代码被强数据卡爆了 遂决定在博客里添加常数优化的参考代码


前言

膜拜xde将这种原本冷门的技巧发扬光大

这是一篇关于线段树懒标记矩阵化的博客,内容大部分仅是我个人的理解,由于笔者实力十分有限,对矩阵和懒标记的理解只浮于表面,本文以及我所有的博客最主要的作用仅是记一个我能看懂的代码和自己的一些理解,欢迎大佬指正下文中我的错误

想要更优秀的博客?


Pert0 引入

有时我们在用线段树处理题目时,需要设计多个标记,而标记之间的运算与优先级常常需要细致的讨论,不光耽误时间也容易出错,那么我们考虑能否更高效的推导如何维护标记与数据,也就是本文的主题——线段树维护矩阵

Part1 从基础开始

考虑最基础的问题:用线段树维护区间和

对于每个区间,维护向量
[sumlen]
同样的,我们将每个区间的懒标记化为[1c01]
那么标记传递时,我们有[sum+len×clen]=[sumlen][1c01]

这样,我们成功将最基础的线段树操作化为了矩阵运算

需要注意的是,由于矩阵运算自带常数,有时并不能通过题目,对于这种情况,我们可以不保留常数(后文代码为了便于理解没有这个优化)

Part2 区间历史最大值

显然如果我们只用矩阵解决那么简单的问题是多此一举,下面我们来看例题

此题需要维护区间最大值与区间历史最大值,如果我们用传统懒标记记录将会异常麻烦,毕竟普通懒标记会将操作合并,然而合并后的懒标记无法用来更新历史最大值,但我们运用矩阵运算可在合并懒标记的同时更新历史最大值

首先定义此题矩阵乘为ci,j=max{ai,k,bk,j}

具体的,我们对于一段区间维护[ab0]
其中a表示当前最大值,b表示历史最大值

对于区间加操作,我们维护形如[kk00]的矩阵,则对区间答案的更新为[kk00][ab0]=[a+kmax{b,a+k}0]
而赋值操作为[k0k0][ab0]=[kmax{b,k}0]
于是在下传标记的同时,我们简洁地维护了区间历史最值

本题代码

点击查看代码
#include<bits/stdc++.h>
#define pb push_back
#define N 100010
#define ll long long
#define ull unsigned long long
#define ls x<<1
#define rs x<<1|1
using namespace std;
int n,T,c;
ll a[N];
const ll inf=1e15;
struct tree {
    ll z[3][1];//第一列当前最大值 第二列历史最大值 第三列常数0
    tree() {  z[0][0]=-inf;z[1][0]=-inf;z[2][0]=0; }
    tree operator+(const tree& p) const{
        tree c;
        for(int i=0;i<3;i++) c.z[i][0]=max(z[i][0],p.z[i][0]);
        return c;
    }
}b[N<<2];
struct tag {
    ll z[3][3];
    tag() { //构造函数,需要注意的是这样会使编译特别慢,后续计算答案或赋值也不方便,所以不建议使用
        for(int i=0;i<3;i++)
            for(int j=0;j<3;j++) z[i][j]=-inf;
        for(int i=0;i<3;i++) z[i][i]=0;
    }tag operator*(const tag& p) {
        tag c;
        for(int i=0;i<3;i++) 
            for(int j=0;j<3;j++){
                ll tmp=-inf;
                for(int k=0;k<3;k++) tmp=max(tmp,z[i][k]+p.z[k][j]); //懒标记间运算
                c.z[i][j]=tmp;
            }       
        return c;
    }void mem() {
        for(int i=0;i<3;i++)
            for(int j=0;j<3;j++) 
                z[i][j]=-inf;
        for(int i=0;i<3;i++) z[i][i]=0;
    }void fz1(ll k) { //覆盖初始懒标记
        z[0][0]=-inf;z[0][2]=k;z[1][2]=k;
    }void fz2(ll k) { //区间加初始懒标记
        z[0][0]=k;z[1][0]=k;
    }tree operator*(const tree& p) const{
        tree c;
        for(int i=0;i<3;i++) 
            for(int j=0;j<1;j++) {
                ll tmp=-inf;for(int k=0;k<3;k++) tmp=max(tmp,z[i][k]+p.z[k][j]); //更新答案
                c.z[i][j]=tmp;
            }
        return c;
    }
}d[N<<2];
void lt(int x,tag c) { b[x]=c*b[x];d[x]=c*d[x];}
void mg(int x) {b[x]=b[ls]+b[rs];}
void pd(int x) {
    lt(ls,d[x]);lt(rs,d[x]);d[x].mem();
}void build(int x,int l,int r) {
    if(l==r) {
        b[x].z[0][0]=a[l];b[x].z[1][0]=a[l];
        return ;
    }int mid=(l+r)>>1;
    build(ls,l,mid);build(rs,mid+1,r);mg(x);
}void mdf(int x,int s,int t,int l,int r,tag c) {
    if(s>=l&&t<=r) {
        lt(x,c);return ;
    }int mid=(s+t)>>1;pd(x);
    if(l<=mid) mdf(ls,s,mid,l,r,c);
    if(r>mid) mdf(rs,mid+1,t,l,r,c);mg(x);
}tree query(int x,int s,int t,int l,int r) {
    if(s>=l&&t<=r) { return b[x];}
    int mid=(s+t)>>1;pd(x);tree ans;
    if(l<=mid) ans=ans+query(ls,s,mid,l,r);
    if(r>mid) ans=ans+query(rs,mid+1,t,l,r);mg(x);
    return ans;
}signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    cin>>n;for(int i=1;i<=n;i++) cin>>a[i];
    int Q;cin>>Q;char op;ll x,y,z;build(1,1,n);
    while(Q--) {
        cin>>op;
        if(op=='Q') {
            cin>>x>>y;printf("%lld\n",query(1,1,n,x,y).z[0][0]);
        }else if(op=='A') {
            cin>>x>>y;printf("%lld\n",query(1,1,n,x,y).z[1][0]);
        }else if(op=='P') {
            cin>>x>>y>>z;tag c;c.fz2(z);
            mdf(1,1,n,x,y,c);
        }else {
            cin>>x>>y>>z;tag c;c.fz1(z);
            mdf(1,1,n,x,y,c);
        }
    }
    return 0;
}

常数优化版:

点击查看代码
#include<bits/stdc++.h>
#define pb push_back
#define N 100010
#define ll long long
#define ull unsigned long long
#define ls x<<1
#define rs x<<1|1
using namespace std;
int n,T,c;
ll a[N];
const ll inf=1e14;
struct tree {
    ll m,h;
    void fz() { m=-inf;h=-inf; }
    tree operator+(const tree& p) const{
        tree c;
        c.m=max(m,p.m);c.h=max(h,p.h);
        return c;
    }
}b[N<<2];
struct tag {
    ll x11,x21,x13,x23;
    void fz() {x11=0;x21=-inf;x13=-inf;x23=-inf;}
    tag operator*(const tag& p) {
        tag c;c.fz();
        c.x11=x11+p.x11;c.x21=max(p.x11+x21,p.x21);c.x13=max(x11+p.x13,x13);c.x23=max(x21+p.x13,max(p.x23,x23));      
        return c;
    }void fz1(ll k) { //覆盖
        fz();x11=-inf;
        x13=k;x23=k;
    }void fz2(ll k) { //区间加
        fz();
        x11=k;x21=k;
    }tree operator*(const tree& p) const{
        tree c;c.fz();
        c.m=max(x11+p.m,x13);c.h=max(p.m+x21,max(p.h,x23));
        return c;
    }
}d[N<<2];
void lt(int x,tag c) { b[x]=c*b[x];d[x]=c*d[x];}
void mg(int x) {b[x]=b[ls]+b[rs];}
void pd(int x) {
    lt(ls,d[x]);lt(rs,d[x]);d[x].fz();
}void build(int x,int l,int r) {
    d[x].fz();b[x].fz();
    if(l==r) {
        b[x].m=a[l];b[x].h=a[l];
        return ;
    }int mid=(l+r)>>1;
    build(ls,l,mid);build(rs,mid+1,r);mg(x);//printf("%lld %d %d\n",b[x].h,l,r);
}void mdf(int x,int s,int t,int l,int r,tag c) {
    if(s>=l&&t<=r) {
        lt(x,c);return ;
    }int mid=(s+t)>>1;pd(x);
    if(l<=mid) mdf(ls,s,mid,l,r,c);
    if(r>mid) mdf(rs,mid+1,t,l,r,c);mg(x);
}tree query(int x,int s,int t,int l,int r) {
    if(s>=l&&t<=r) { return b[x];}
    int mid=(s+t)>>1;tree ans;ans.fz();pd(x);
    if(l<=mid) ans=ans+query(ls,s,mid,l,r);
    if(r>mid) ans=ans+query(rs,mid+1,t,l,r);mg(x);
    
    return ans;
}signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    cin>>n;for(int i=1;i<=n;i++) cin>>a[i];
    int Q;cin>>Q;char op;ll x,y,z;build(1,1,n);
    while(Q--) {
        cin>>op;
        if(op=='Q') {
            cin>>x>>y;printf("%lld\n",query(1,1,n,x,y).m);
        }else if(op=='A') {
            cin>>x>>y;printf("%lld\n",query(1,1,n,x,y).h);
        }else if(op=='P') {
            cin>>x>>y>>z;tag c;c.fz2(z);
            mdf(1,1,n,x,y,c);
        }else {
            cin>>x>>y>>z;tag c;c.fz1(z);
            mdf(1,1,n,x,y,c);
        }
    }
    return 0;
}

这个题有个小细节,如果不在代码里判tag里的数和inf的关系的话inf乱赋很容易爆掉,不过我换了一次inf就过了。。。

Part3 区间历史和

例题

此题裸暴力可获得高达8分的成绩

考虑优化,将询问离线,从1到n枚举r

设计数组h,对于每个r有hi=j=irXi,jYi,j

也就是固定左端点的情况下每个子区间的贡献和,其中X与Y为区间内最大值

则对于每个右端点为当前r的询问,答案为i=lrhi
时间复杂度为O(n2+qn)
这样我们就获得了20分!

点击查看代码
#include<bits/stdc++.h>
#define pb push_back
#define N 3010
#define ll long long
#define ull unsigned long long
#define ls x<<1
#define rs x<<1|1
using namespace std;
int n,m,c,st[N],ct;
ull a[N],b[N],ans[N],h[N],X[N][N],Y[N][N];
struct node{ int l,id; };
vector<node> v[N];
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    cin>>c>>n;
    for(int i=1;i<=n;i++) cin>>a[i];
    for(int i=1;i<=n;i++) cin>>b[i];
    for(int i=1;i<=n;i++) 
        for(int j=i;j<=n;j++) 
            X[i][j]=max(X[i][j-1],a[j]),Y[i][j]=max(Y[i][j-1],b[j]);
    cin>>m;int x,y;
    for(int i=1;i<=m;i++) cin>>x>>y,v[y].push_back({x,i});
    for(int i=1;i<=n;i++) {
        for(int j=1;j<=i;j++) h[j]+=X[j][i]*Y[j][i];
        for(int k=0;k<v[i].size();k++) {
            int id=v[i][k].id,l=v[i][k].l;
            for(int j=l;j<=i;j++) ans[id]+=h[j];
        } 
            
    }
    for(int i=1;i<=m;i++) printf("%llu\n",ans[i]);
    return 0;
}

继续优化,我们发现瓶颈在于h数组的更新与求和以及最大值的更新

我们考虑用线段树维护h的和,现在思考当r+1时如何更新h

回顾暴力代码,可以使用单调栈优化,记录每个元素的值,下标,以及左边第一个大于此元素的下标,那么对于所有在这两个下标间的h,当前元素会产生贡献

我们考虑用线段树维护矩阵[ababhlen]
其中Ai=maxaj,j[i,r]Bi=maxbj,j[i,r]a=i=lrAib=i=lrBiab=i=lrAiBi

而c为ab的历史和,len为区间长度

接下来就很好做了,每次我们枚举到一个新的r时,用单调栈更新每个元素所“管辖”的区域,如果一个元素被弹掉了,代表它所在的区间中所有h要用当前r位置的元素更新,我们可以用线段树进行区间加更新,即将此区间原本最大值加上增量,具体的,有[1000k010000k1000001000001][ababclen]=[a+k×lenbab+k×bclen][100000100kk01000001000001][ababclen]=[ab+k×lenab+k×aclen][1000001000001000011000001][ababclen]=[ababc+ablen]

注意到,此处处理版本历史和的精髓在于每次操作都下传了一个更新版本历史和的标记,我刚学的时候直接把这个地方忽略了,从而不理解正确性(太糖了)。

本题矩阵常数较大,如果不优化难以通过,于是我们只维护非常数的位置。同时要注意,在标记之间运算时也会出现非常数的位置,这里我懒得再列标记运算的过程了,可以自己算一下,于是常数就被我们优化好了

下面给出代码 可读性应该还算高(吧?)

点击查看代码
#include<bits/stdc++.h>
#define pb push_back
#define N 250010
#define ll long long
#define ull unsigned long long
#define ls x<<1
#define rs x<<1|1
using namespace std;
int n,m,c,ca,cb;
ull A[N],B[N],ans[N];
struct Q{ int l,id; };
vector<Q> v[N];
struct node{int l,r;ull val;}sa[N],sb[N];
struct tree{
    ull a,b,ab,h,len;
    void fz() { a=0;b=0;ab=0;h=0;len=1; }
    tree operator+(const tree& p) const{ return {a+p.a,b+p.b,ab+p.ab,h+p.h,len+p.len};} //线段树合并
}b[N<<2];
struct tag{
    ull x15,x32,x25,x31,x43,x35,x41,x42,x45;
    void fz() { x15=0;x32=0;x25=0;x31=0;x43=0;x35=0;x41=0;x42=0;x45=0;}
    void fz1(ull c) {fz();x15=c;x32=c;}
    void fz2(ull c) {fz();x25=c;x31=c;}
    void fz3() {fz();x43=1;}
    tag operator*(const tag& p) const{ //标记间运算。。比较难写 很容易写错
        tag z;z.fz();
        z.x15=p.x15+x15;
        z.x32=p.x32+x32;
        z.x25=p.x25+x25;
        z.x31=p.x31+x31;
        z.x43=p.x43+x43;
        z.x35=p.x35+x35+x31*p.x15+x32*p.x25;
        z.x41=p.x41+x41+x43*p.x31;
        z.x42=p.x42+x42+x43*p.x32;
        z.x45=p.x45+x45+x43*p.x35+x41*p.x15+x42*p.x25;
        return z;
    }tree operator*(const tree& p) const{ //标记更新答案
        tree z;
        z.a=p.a+p.len*x15;z.b=p.b+x25*p.len;
        z.ab=x31*p.a+x32*p.b+p.ab+x35*p.len;
        z.h=x41*p.a+x42*p.b+x45*p.len+p.h+p.ab*x43;  z.len=p.len;
        return z;
    }
}d[N<<2];
void mg(int x) {b[x]=b[ls]+b[rs];} 
void lt(int x,tag c) { b[x]=c*b[x];d[x]=c*d[x];}
void pd(int x) {
    lt(ls,d[x]);lt(rs,d[x]);d[x].fz();
}void build(int x,int l,int r) {
    if(l==r) {
        b[x].fz();d[x].fz();return ;
    }int mid=(l+r)>>1;
    build(ls,l,mid);build(rs,mid+1,r);b[x].fz();mg(x);d[x].fz();
}void mdf(int x,int s,int t,int l,int r,tag c) {
    if(s>=l&&t<=r) {lt(x,c);return ;}
    int mid=(s+t)>>1;pd(x);
    if(mid>=l) mdf(ls,s,mid,l,r,c);
    if(mid<r) mdf(rs,mid+1,t,l,r,c);mg(x);
}tree query(int x,int s,int t,int l,int r) {
    if(s>=l&&t<=r) {return b[x];}
    int mid=(s+t)>>1;pd(x);tree ans;ans.fz();
    if(mid>=l) ans=ans+query(ls,s,mid,l,r);
    if(mid<r) ans=ans+query(rs,mid+1,t,l,r);
    return ans;
}signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);cin>>c>>n;
    for(int i=1;i<=n;i++) cin>>A[i];
    for(int i=1;i<=n;i++) cin>>B[i];
    cin>>m;int x,y;build(1,1,n);
    for(int i=1;i<=m;i++) cin>>x>>y,v[y].push_back({x,i});
    sa[0]={0,0,0};sb[0]={0,0,0};tag c;
    for(int i=1;i<=n;i++) {
        while(ca&&sa[ca].val<A[i]) {c.fz1(A[i]-sa[ca].val),mdf(1,1,n,sa[ca].l,sa[ca].r,c),ca--;}//单调栈更新之前的h
        while(cb&&sb[cb].val<B[i]) {c.fz2(B[i]-sb[cb].val),mdf(1,1,n,sb[cb].l,sb[cb].r,c),cb--;}
        c.fz1(A[i]);mdf(1,1,n,i,i,c);//新位置更新
        c.fz2(B[i]);mdf(1,1,n,i,i,c);
        c.fz3();mdf(1,1,n,1,i,c);//更新所有h的答案
        sa[ca+1]={sa[ca].r+1,i,A[i]};sb[cb+1]={sb[cb].r+1,i,B[i]};ca++;cb++;//把当前位置压进栈
        for(int j=0;j<v[i].size();j++) ans[v[i][j].id]+=query(1,1,n,v[i][j].l,i).h;//计算答案
    }
    for(int i=1;i<=m;i++) printf("%llu\n",ans[i]);
    return 0;
}

Part4 后记

这篇博客写的也一如既往的简陋。。其实还是题见的太少,等后面做更多的题后有新的感悟话会再补充

感叹第一个提出这种思想的人一定是天才吧,其实也不是什么很高深的东西,但学会后还是很开心,希望以后学习也能一直保持这种鲜活的心情,而不是因为残酷的现实学麻了......

posted @ 2025-09-03 19:21  he_qwq  阅读(24)  评论(0)    收藏  举报