线段树优化 dp 学习笔记
前言
挺好玩的,感觉什么东西撞上线段树了就降智了。
你说得对,但不如直接上线段树维护
内容
首先得写出最简单的状态转移方程,观察方程有什么前缀后缀和,什么区间 max/min ,之类的东西,就可以考虑线段树维护了。
然后我们就把这些值直接扔线段树上就行。
具体看题。
例题 A:Linear Kingdom Races
定义 \(f_i\) 表示遍历到第 \(i\) 条路时获得的最大收益,则有状态转移方程如下:
\[f_i=\max(f_{i-1},\max_{1\le j \le i}(f_j+val_{j+1,i}-cost_{j+1,i}))
\]
我们把后面的 \(\max\) 直接扔到线段树上,将修路按右端点离线,每次修到一个右端点就去统计。
点击查看代码
#include<iostream>
#include<vector>
#define ls u*2
#define rs u*2+1
#define int long long
using namespace std;
const int N=2e5+10;
int tr[4*N],tag[4*N];
int f[N];
struct node{
int x,val;
};
vector<node> vec[N];
int n,m,p[N];
void pushup(int u){
tr[u]=max(tr[ls],tr[rs]);
return ;
}
void upd(int u,int k){
tr[u]+=k;
tag[u]+=k;
return ;
}
void pushdown(int u){
if(!tag[u]) return ;
upd(ls,tag[u]);
upd(rs,tag[u]);
tag[u]=0;
return ;
}
int query(int u,int l,int r,int x,int y){
if(l>=x && r<=y){
return tr[u];
}
pushdown(u);
int mid=(l+r)>>1;
int res=-1;
if(x<=mid) res=max(res,query(ls,l,mid,x,y));
if(mid<y) res=max(res,query(rs,mid+1,r,x,y));
return res;
}
void modify(int u,int l,int r,int x,int y,int k){
if(l>=x && r<=y){
upd(u,k);
return ;
}
pushdown(u);
int mid=(l+r)>>1;
if(x<=mid) modify(ls,l,mid,x,y,k);
if(mid<y) modify(rs,mid+1,r,x,y,k);
pushup(u);
return ;
}
void mdf(int u,int l,int r,int x,int k){
if(l==x && r==x){
tag[u]=0;
tr[u]=k;
return ;
}
pushdown(u);
int mid=(l+r)>>1;
if(x<=mid) mdf(ls,l,mid,x,k);
else mdf(rs,mid+1,r,x,k);
pushup(u);
return ;
}
signed main(){
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>p[i];
for(int i=1;i<=m;i++){
int l,r,p;
cin>>l>>r>>p;
vec[r].push_back({l,p});
}
for(int i=1;i<=n;i++){
modify(1,0,n,0,i-1,-p[i]);
for(node a:vec[i]){
int x=a.x,val=a.val;
modify(1,0,n,0,x-1,val);
}
f[i]=max(f[i-1],query(1,0,n,0,i-1));
mdf(1,0,n,i,f[i]);
}
cout<<f[n];
return 0;
}
例题 B:LEQ and NEQ
大部分题解 单调栈+奇偶性滚一维 的操作太科幻了,我们使用线段树!
定义 \(f_{i,j}\) 表示第 \(i\) 个位置选 \(j\) 的方案数,有如下转移:
\[f_{i,j}=\sum_{k=1}^{a_i-1}f_{i-1,k}-f_{i-1,j} (j\le a_i)
\]
\(j>a_i\) 显然为 0。
转移在线段树上就表现为:当前和 + \(1\to a_{i-1}\) 区间取反。对于 \(j>a_i\) 的情况我们区间赋 0 即可。
点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=5e5+10;
const int p=998244353;
int n,a[N];
struct Tree{
int ls,rs;
ll val,tag;
bool f1,f2;//区间取反,区间清零
}tr[32*N];
int tot,rt,mx;
void pushup(int u){
tr[u].val=(tr[tr[u].ls].val+tr[tr[u].rs].val)%p;
}
void init(int u){
tr[u].val=tr[u].tag=0;
tr[u].ls=tr[u].rs=0;
tr[u].f1=0;
tr[u].f2=1;
}
void upd(int u,int len,bool f,ll k){
if(f){
tr[u].val=-tr[u].val;
tr[u].tag=-tr[u].tag;
tr[u].f1^=1;
}
if(k){
tr[u].val=(tr[u].val+k*len)%p;
tr[u].tag=(tr[u].tag+k)%p;
}
return ;
}
void pushdown(int u,int l,int r){
if(!tr[u].ls) tr[u].ls=++tot;
if(!tr[u].rs) tr[u].rs=++tot;
if(tr[u].f2){
init(tr[u].ls);init(tr[u].rs);
tr[u].f2=0;
}
int mid=(l+r)>>1;
upd(tr[u].ls,mid-l+1,tr[u].f1,tr[u].tag);
upd(tr[u].rs,r-mid,tr[u].f1,tr[u].tag);
tr[u].f1=0;tr[u].tag=0;
}
void mdf(int &u,int l,int r,int x,int y){
if(!u) u=++tot;
if(l>=x && r<=y){
init(u);
return ;
}
pushdown(u,l,r);
int mid=(l+r)>>1;
if(x<=mid) mdf(tr[u].ls,l,mid,x,y);
if(mid<y) mdf(tr[u].rs,mid+1,r,x,y);
pushup(u);
}
void modify(int &u,int l,int r,int x,int y,ll k,bool op){
if(!u) u=++tot;
if(l>=x && r<=y){
upd(u,r-l+1,op,k);
return ;
}
pushdown(u,l,r);
int mid=(l+r)>>1;
if(x<=mid) modify(tr[u].ls,l,mid,x,y,k,op);
if(mid<y) modify(tr[u].rs,mid+1,r,x,y,k,op);
pushup(u);
}
int main(){
#ifdef LOCAL
freopen("D:/Desktop/cpp/data/code.in","r",stdin);
freopen("D:/Desktop/cpp/data/code.out","w",stdout);
#endif
cin>>n;
for(int i=1;i<=n;i++){
cin>>a[i];
mx=max(mx,a[i]);
}
modify(rt,1,mx,1,a[1],1,0);
for(int i=2;i<=n;i++){
if(a[i]>=a[i-1]){
ll k=tr[rt].val;
modify(rt,1,mx,1,a[i-1],0,1);
modify(rt,1,mx,1,a[i],k,0);
}else{
ll k=tr[rt].val;
mdf(rt,1,mx,a[i]+1,a[i-1]);
modify(rt,1,mx,1,a[i],0,1);
modify(rt,1,mx,1,a[i],k,0);
}
}
cout << (tr[1].val+p) %p;
return 0;
}
例题 C:Minimax
似乎是 dbg 的线段树合并题。怎么黑了
定义 \(f_{i,j}\) 表示第 \(i\) 个节点出现 \(j\) 的概率,则转移方程如下:
\[f_{i,j}=f_{l,j}\times (p_i\times\sum_{k=1}^{j-1}f_{r,k}+(1-p_i)\times \sum_{k=j+1}^{m}f_{r,k})+f_{r,j}\times (p_i\times\sum_{k=1}^{j-1}f_{l,k}+(1-p_i)\times \sum_{k=j+1}^{m}f_{l,k})
\]
发现这些东西是个前缀和,后缀和的玩意,考虑扔到线段树上,转移就用线段树合并。
值域大需要离散化,合并的时候需要提前保存参数,其它的都是正常线段树操作。
点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=3e5+10;
const int p=998244353;
ll qpow(ll a,int b){
ll res=1;
while(b){
if(b&1) res=res*a%p;
a=a*a%p;
b>>=1;
}
return res%p;
}
int n,cnt[N],ch[N][2];
ll val[N],tmp[N];
int sum;
ll s[N];
int ans=0;
struct Tree{
int ls,rs;
ll val,tag;
}tr[50*N];
int rt[N],tot=0;
int build(){
tot++;
tr[tot].ls=tr[tot].rs=tr[tot].val=0;
tr[tot].tag=1;
return tot;
}
void upd(int u,ll k){
if(!u) return ;
tr[u].val=1ll*tr[u].val*k%p;
tr[u].tag=1ll*tr[u].tag*k%p;
}
void pushup(int u){
tr[u].val=(tr[tr[u].ls].val+tr[tr[u].rs].val)%p;
}
void pushdown(int u){
if(tr[u].tag==1) return ;
if(tr[u].ls) upd(tr[u].ls,tr[u].tag);
if(tr[u].rs) upd(tr[u].rs,tr[u].tag);
tr[u].tag=1;
}
void modify(int &u,int l,int r,int x,ll k){
if(!u) u=build();
if(l==r){
tr[u].val=k;
return ;
}
pushdown(u);
int mid=(l+r)>>1;
if(x<=mid) modify(tr[u].ls,l,mid,x,k);
else modify(tr[u].rs,mid+1,r,x,k);
pushup(u);
}
int merge(int u,int v,int l,int r,ll umul,ll vmul,ll val){
if(!u && !v) return 0;
if(!u){
upd(v,vmul);
return v;
}
if(!v){
upd(u,umul);
return u;
}
pushdown(u);pushdown(v);
int mid=(l+r)>>1;
int uls=tr[tr[u].ls].val,urs=tr[tr[u].rs].val;
int vls=tr[tr[v].ls].val,vrs=tr[tr[v].rs].val;
// cout << tr[u].ls <<' '<<tr[u].rs << ' '<<tr[v].ls << ' '<<tr[v].rs << '\n';
tr[u].ls=merge(tr[u].ls,tr[v].ls,l,mid,(umul+1ll*vrs%p*(1-val+p)%p)%p,(vmul+1ll*urs%p*(1-val+p)%p)%p,val);
tr[u].rs=merge(tr[u].rs,tr[v].rs,mid+1,r,(umul+1ll*vls%p*val%p)%p,(vmul+1ll*uls%p*val%p),val);
// cout << tr[u].ls<<' '<<tr[u].rs<<'\n';
pushup(u);
return u;
}
void query(int u,int l,int r){
if(!u) return ;
if(l==r){
s[l]=tr[u].val;
return ;
}
pushdown(u);
int mid=(l+r)>>1;
query(tr[u].ls,l,mid);
query(tr[u].rs,mid+1,r);
}
void debug(int u,int l,int r){
cout<<u<<' '<< tr[u].ls <<' '<<tr[u].rs<<'\n';
if(!u) return ;
if(l==r){
return ;
}
int mid=(l+r)>>1;
debug(tr[u].ls,l,mid);
debug(tr[u].rs,mid+1,r);
}
void dfs(int u){
if(!cnt[u]) modify(rt[u],1,sum,val[u],1);
if(cnt[u]==1){
dfs(ch[u][0]);rt[u]=rt[ch[u][0]];
}
if(cnt[u]==2){
dfs(ch[u][0]);dfs(ch[u][1]);
rt[u]=merge(rt[ch[u][0]],rt[ch[u][1]],1,sum,0,0,val[u]);
}
// cout << u << '\n';debug(rt[u],1,sum);
return ;
}
int main(){
#ifdef LOCAL
freopen("D:/Desktop/cpp/data/code.in","r",stdin);
freopen("D:/Desktop/cpp/data/code.out","w",stdout);
#endif
cin>>n;
for(int i=1;i<=n;i++){
int x;
cin>>x;
if(x==0) continue;
if(ch[x][0]) ch[x][1]=i;
else ch[x][0]=i;
cnt[x]++;
}
ll inv=qpow(10000,p-2)%p;
for(int i=1;i<=n;i++){
cin>>val[i];
if(cnt[i]){
val[i]=1ll*val[i]*inv%p;
}else{
tmp[++sum]=val[i];
}
}
sort(tmp+1,tmp+1+sum);
for(int i=1;i<=n;i++){
if(!cnt[i]) val[i]=lower_bound(tmp+1,tmp+1+sum,val[i])-tmp;
}
dfs(1);
query(rt[1],1,sum);
for(int i=1;i<=sum;i++){
ans=(ans+1ll*i%p*s[i]%p*s[i]%p*tmp[i]%p)%p;
}
cout << ans;
return 0;
}
这里是我写挂的几点原因:
- pushup 没取模。
- val 和 tag 用混。
- 传参没开 long long。