线段树优化 dp 学习笔记

前言

挺好玩的,感觉什么东西撞上线段树了就降智了。

你说得对,但不如直接上线段树维护

内容

首先得写出最简单的状态转移方程,观察方程有什么前缀后缀和,什么区间 max/min ,之类的东西,就可以考虑线段树维护了。
然后我们就把这些值直接扔线段树上就行。
具体看题。

例题 A:Linear Kingdom Races

定义 \(f_i\) 表示遍历到第 \(i\) 条路时获得的最大收益,则有状态转移方程如下:

\[f_i=\max(f_{i-1},\max_{1\le j \le i}(f_j+val_{j+1,i}-cost_{j+1,i})) \]

我们把后面的 \(\max\) 直接扔到线段树上,将修路按右端点离线,每次修到一个右端点就去统计。

点击查看代码
#include<iostream>
#include<vector>
#define ls u*2
#define rs u*2+1
#define int long long
using namespace std;
const int N=2e5+10;
int tr[4*N],tag[4*N];
int f[N];
struct node{
	int x,val;
};
vector<node> vec[N];
int n,m,p[N];
void pushup(int u){
	tr[u]=max(tr[ls],tr[rs]);
	return ;
}
void upd(int u,int k){
	tr[u]+=k;
	tag[u]+=k;
	return ;
}
void pushdown(int u){
	if(!tag[u]) return ;
	upd(ls,tag[u]);
	upd(rs,tag[u]);
	tag[u]=0;
	return ;
}
int query(int u,int l,int r,int x,int y){
	if(l>=x && r<=y){
		return tr[u];
	}
	pushdown(u);
	int mid=(l+r)>>1;
	int res=-1;
	if(x<=mid) res=max(res,query(ls,l,mid,x,y));
	if(mid<y) res=max(res,query(rs,mid+1,r,x,y));
	return res;
}
void modify(int u,int l,int r,int x,int y,int k){
	if(l>=x && r<=y){
		upd(u,k);
		return ;
	}
	pushdown(u);
	int mid=(l+r)>>1;
	if(x<=mid) modify(ls,l,mid,x,y,k);
	if(mid<y) modify(rs,mid+1,r,x,y,k);
	pushup(u);
	return ;
}
void mdf(int u,int l,int r,int x,int k){
	if(l==x && r==x){
		tag[u]=0;
		tr[u]=k;
		return ;
	}
	pushdown(u);
	int mid=(l+r)>>1;
	if(x<=mid) mdf(ls,l,mid,x,k);
	else mdf(rs,mid+1,r,x,k);
	pushup(u);
	return ;
}
signed main(){
	cin>>n>>m;
	for(int i=1;i<=n;i++) cin>>p[i];
	for(int i=1;i<=m;i++){
		int l,r,p;
		cin>>l>>r>>p;
		vec[r].push_back({l,p});
	}
	for(int i=1;i<=n;i++){
		modify(1,0,n,0,i-1,-p[i]);
		for(node a:vec[i]){
			int x=a.x,val=a.val;
			modify(1,0,n,0,x-1,val);
		}
		f[i]=max(f[i-1],query(1,0,n,0,i-1));
		mdf(1,0,n,i,f[i]);
	}
	cout<<f[n];
	return 0;
}

例题 B:LEQ and NEQ

大部分题解 单调栈+奇偶性滚一维 的操作太科幻了,我们使用线段树!

定义 \(f_{i,j}\) 表示第 \(i\) 个位置选 \(j\) 的方案数,有如下转移:

\[f_{i,j}=\sum_{k=1}^{a_i-1}f_{i-1,k}-f_{i-1,j} (j\le a_i) \]

\(j>a_i\) 显然为 0。
转移在线段树上就表现为:当前和 + \(1\to a_{i-1}\) 区间取反。对于 \(j>a_i\) 的情况我们区间赋 0 即可。

点击查看代码
#include<bits/stdc++.h>
#define ll long long

using namespace std;
const int N=5e5+10;
const int p=998244353;
int n,a[N];
struct Tree{
    int ls,rs;
    ll val,tag;
    bool f1,f2;//区间取反,区间清零
}tr[32*N];
int tot,rt,mx;
void pushup(int u){
    tr[u].val=(tr[tr[u].ls].val+tr[tr[u].rs].val)%p;
}
void init(int u){
    tr[u].val=tr[u].tag=0;
    tr[u].ls=tr[u].rs=0;
    tr[u].f1=0;
    tr[u].f2=1;
}
void upd(int u,int len,bool f,ll k){
    if(f){
        tr[u].val=-tr[u].val;
        tr[u].tag=-tr[u].tag;
        tr[u].f1^=1;
    }
    if(k){
        tr[u].val=(tr[u].val+k*len)%p;
        tr[u].tag=(tr[u].tag+k)%p;
    }
    return ;
}
void pushdown(int u,int l,int r){
    if(!tr[u].ls) tr[u].ls=++tot;
    if(!tr[u].rs) tr[u].rs=++tot;
    if(tr[u].f2){
        init(tr[u].ls);init(tr[u].rs);
        tr[u].f2=0;
    } 
    int mid=(l+r)>>1;
    upd(tr[u].ls,mid-l+1,tr[u].f1,tr[u].tag);
    upd(tr[u].rs,r-mid,tr[u].f1,tr[u].tag);
    tr[u].f1=0;tr[u].tag=0;
}
void mdf(int &u,int l,int r,int x,int y){
    if(!u) u=++tot;
    if(l>=x && r<=y){
        init(u);
        return ;
    }
    pushdown(u,l,r);
    int mid=(l+r)>>1;
    if(x<=mid) mdf(tr[u].ls,l,mid,x,y);
    if(mid<y) mdf(tr[u].rs,mid+1,r,x,y);
    pushup(u);
}
void modify(int &u,int l,int r,int x,int y,ll k,bool op){
    if(!u) u=++tot;
    if(l>=x && r<=y){
        upd(u,r-l+1,op,k);
        return ;
    }
    pushdown(u,l,r);
    int mid=(l+r)>>1;
    if(x<=mid) modify(tr[u].ls,l,mid,x,y,k,op);
    if(mid<y) modify(tr[u].rs,mid+1,r,x,y,k,op);
    pushup(u);
}
int main(){
#ifdef LOCAL
    freopen("D:/Desktop/cpp/data/code.in","r",stdin);
    freopen("D:/Desktop/cpp/data/code.out","w",stdout);
#endif
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>a[i];
        mx=max(mx,a[i]);
    }
    modify(rt,1,mx,1,a[1],1,0);
    for(int i=2;i<=n;i++){
        if(a[i]>=a[i-1]){
            ll k=tr[rt].val;
            modify(rt,1,mx,1,a[i-1],0,1);
            modify(rt,1,mx,1,a[i],k,0);
        }else{
            ll k=tr[rt].val;
            mdf(rt,1,mx,a[i]+1,a[i-1]);
            modify(rt,1,mx,1,a[i],0,1);
            modify(rt,1,mx,1,a[i],k,0);
        }
    }
    cout << (tr[1].val+p) %p;
    return 0;
}

例题 C:Minimax

似乎是 dbg 的线段树合并题。怎么黑了

定义 \(f_{i,j}\) 表示第 \(i\) 个节点出现 \(j\) 的概率,则转移方程如下:

\[f_{i,j}=f_{l,j}\times (p_i\times\sum_{k=1}^{j-1}f_{r,k}+(1-p_i)\times \sum_{k=j+1}^{m}f_{r,k})+f_{r,j}\times (p_i\times\sum_{k=1}^{j-1}f_{l,k}+(1-p_i)\times \sum_{k=j+1}^{m}f_{l,k}) \]

发现这些东西是个前缀和,后缀和的玩意,考虑扔到线段树上,转移就用线段树合并。
值域大需要离散化,合并的时候需要提前保存参数,其它的都是正常线段树操作。

点击查看代码
#include<bits/stdc++.h>
#define ll long long

using namespace std;
const int N=3e5+10;
const int p=998244353;
ll qpow(ll a,int b){
    ll res=1;
    while(b){
        if(b&1) res=res*a%p;
        a=a*a%p;
        b>>=1;
    }
    return res%p;
}
int n,cnt[N],ch[N][2];
ll val[N],tmp[N];
int sum;
ll s[N];
int ans=0;
struct Tree{
    int ls,rs;
    ll val,tag;
}tr[50*N];
int rt[N],tot=0;
int build(){
    tot++;
    tr[tot].ls=tr[tot].rs=tr[tot].val=0;
    tr[tot].tag=1;
    return tot;
}
void upd(int u,ll k){
    if(!u) return ;
    tr[u].val=1ll*tr[u].val*k%p;
    tr[u].tag=1ll*tr[u].tag*k%p;
}
void pushup(int u){
    tr[u].val=(tr[tr[u].ls].val+tr[tr[u].rs].val)%p;
}
void pushdown(int u){
    if(tr[u].tag==1) return ;
    if(tr[u].ls) upd(tr[u].ls,tr[u].tag);
    if(tr[u].rs) upd(tr[u].rs,tr[u].tag);
    tr[u].tag=1;
}
void modify(int &u,int l,int r,int x,ll k){
    if(!u) u=build();
    if(l==r){
        tr[u].val=k;
        return ;
    }
    pushdown(u);
    int mid=(l+r)>>1;
    if(x<=mid) modify(tr[u].ls,l,mid,x,k);
    else modify(tr[u].rs,mid+1,r,x,k);
    pushup(u);
}
int merge(int u,int v,int l,int r,ll umul,ll vmul,ll val){
    if(!u && !v) return 0;
    if(!u){
        upd(v,vmul);
        return v;
    }
    if(!v){
        upd(u,umul);
        return u;
    }
    pushdown(u);pushdown(v);
    int mid=(l+r)>>1;
    int uls=tr[tr[u].ls].val,urs=tr[tr[u].rs].val;
    int vls=tr[tr[v].ls].val,vrs=tr[tr[v].rs].val;
//    cout << tr[u].ls <<' '<<tr[u].rs << ' '<<tr[v].ls << ' '<<tr[v].rs << '\n';
    tr[u].ls=merge(tr[u].ls,tr[v].ls,l,mid,(umul+1ll*vrs%p*(1-val+p)%p)%p,(vmul+1ll*urs%p*(1-val+p)%p)%p,val);
    tr[u].rs=merge(tr[u].rs,tr[v].rs,mid+1,r,(umul+1ll*vls%p*val%p)%p,(vmul+1ll*uls%p*val%p),val);
//    cout << tr[u].ls<<' '<<tr[u].rs<<'\n';
    pushup(u);
    return u;
}
void query(int u,int l,int r){
    if(!u) return ;
    if(l==r){
        s[l]=tr[u].val;
        return ;
    }
    pushdown(u);
    int mid=(l+r)>>1;
    query(tr[u].ls,l,mid);
    query(tr[u].rs,mid+1,r);
}
void debug(int u,int l,int r){
    cout<<u<<' '<< tr[u].ls <<' '<<tr[u].rs<<'\n';
    if(!u) return ;
    if(l==r){
        return ;
    }
    int mid=(l+r)>>1;
    debug(tr[u].ls,l,mid);
    debug(tr[u].rs,mid+1,r);
}
void dfs(int u){
    if(!cnt[u]) modify(rt[u],1,sum,val[u],1);
    if(cnt[u]==1){
        dfs(ch[u][0]);rt[u]=rt[ch[u][0]];
    }
    if(cnt[u]==2){
        dfs(ch[u][0]);dfs(ch[u][1]);
        rt[u]=merge(rt[ch[u][0]],rt[ch[u][1]],1,sum,0,0,val[u]);
    }
//   cout << u << '\n';debug(rt[u],1,sum);
    return ;
}

int main(){
#ifdef LOCAL
freopen("D:/Desktop/cpp/data/code.in","r",stdin);
freopen("D:/Desktop/cpp/data/code.out","w",stdout);
#endif
    cin>>n;
    for(int i=1;i<=n;i++){
        int x;
        cin>>x;
        if(x==0) continue;
        if(ch[x][0]) ch[x][1]=i;
        else ch[x][0]=i;
        cnt[x]++;
    }
    ll inv=qpow(10000,p-2)%p;
    for(int i=1;i<=n;i++){
        cin>>val[i];
        if(cnt[i]){
            val[i]=1ll*val[i]*inv%p;
        }else{
            tmp[++sum]=val[i]; 
        }
    }
    sort(tmp+1,tmp+1+sum);
    for(int i=1;i<=n;i++){
        if(!cnt[i]) val[i]=lower_bound(tmp+1,tmp+1+sum,val[i])-tmp;
    }
    dfs(1);
    query(rt[1],1,sum);
    for(int i=1;i<=sum;i++){
        ans=(ans+1ll*i%p*s[i]%p*s[i]%p*tmp[i]%p)%p;
    }
    cout << ans;
    return 0;
}

这里是我写挂的几点原因:

  1. pushup 没取模。
  2. val 和 tag 用混。
  3. 传参没开 long long。
posted @ 2025-05-14 21:10  Tighnari  阅读(9)  评论(0)    收藏  举报