题解:P7867 「EVOI-RD1」马戏团
思路分析
设 \(dp_{i,j}\) 为选择第 \(j\) 个至第 \(i\) 个连续舞台的最大收益。分类讨论一下。
- 对于 \(j<i\) 的情况,我们考虑 \(dp_{i,j}\) 比 \(dp_{i-1,j}\) 的答案多了什么。发现其实是多了所有演出中 \(r_k=i\) 且 \(j \le l_k\) 的收益,但是又多加固了第 \(i\) 个舞台,所以要减去 \(c_i\)。所以得到状态转移方程:
\[dp_{i,j}=dp_{i-1,j}-c_i+\sum^{m}_{k=1}[j \le l_k,r_k=i]v_k,j<i
\]
- 对于 \(j=i\) 的情况,我们发现上一个状态可以是任意 \(dp_{n,m}(1 \le n \le m < i)\),但我们可以发现,我们每次更新答案更新的就是这些东西,所以得到状态转移方程:
\[dp_{i,i}=ans-c_i+\sum^{m}_{k=1}[ l_k=r_k=i]v_k,j=i
\]
但我们会发现当我们从状态 \(dp_{i-1,j}\) 转移的时候,没有加上 \(j \le l_k < i,r_k=i\) 的 \(v_k\) 值。有没有感觉这个式子有点熟悉,其实这东西就是 \(dp_{i,j}\) 的转移方程,更新答案的时候会一起更新求最大值,所以可以不用管它。
首先先来优化空间,发现 \(dp_{i,j}\) 的转移只和 \(dp{i-1,j}\) 有关,所以可以滚动数组省去 \(i\) 这一维。
那如何优化时间呢?我们可以发现,对于每一个演出,它一定会更新所有满足 \(i=r_k,j \le l_k\) 的 \(dp_{i,j}\) 即 \(dp_j\)。所以当我们枚举 \(i\) 的时候,可以找到所有 \(r_k=i\) 的演出,然后将 \(1 \le j \le l_k\) 的 \(dp_j\) 都加上一个 \(v_k\)。
区间加操作,我们可以维护线段树,线段树的每个点表示每个 \(dp\) 值,然后此操作就可以变成 \(O(\log(n))\) 的时间复杂度了。
那如何找到所有的 \(r_k=i\) 的演出呢,我们可以使用双指针的思路来维护,这样均摊时间复杂度加起来是 \(O(n+m)\) 的,这样我们的时间复杂度就变为 \(O((n+m) \times \log(n))\) 的了。然后就可以愉快的切掉此题。
这,就是传说中的线段树优化 DP。
AC code
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define mid (l+r>>1)
const int N=1e6+10;
int n,m,ans;
int cost[N],tr[N<<2],lan[N<<2];
struct node{
int l,r,v;
}a[N];
inline int read(){
int t=0,f=1;
register char c=getchar();
while(c<'0'||c>'9') f=(c=='-')?(-1):(f),c=getchar();
while(c>='0'&&c<='9') t=(t<<3)+(t<<1)+(c^48),c=getchar();
return t*f;
}
bool cmp(node x,node y){return x.r<y.r;}
void pushup(int bian){tr[bian]=max(tr[bian<<1],tr[bian<<1|1]);}
void pushdown(int bian){
tr[bian<<1]+=lan[bian];
tr[bian<<1|1]+=lan[bian];
lan[bian<<1]+=lan[bian];
lan[bian<<1|1]+=lan[bian];
lan[bian]=0;
}
void update(int bian,int l,int r,int L,int R,int x){
if(L<=l&&R>=r){tr[bian]+=x,lan[bian]+=x;return;}
if(lan[bian]) pushdown(bian);
if(L<=mid) update(bian<<1,l,mid,L,R,x);
if(R>mid) update(bian<<1|1,mid+1,r,L,R,x);
pushup(bian);
}
signed main(){
n=read(),m=read();
for(int i=1;i<=n;i++) cost[i]=read();
for(int i=1;i<=m;i++) a[i].l=read(),a[i].r=read(),a[i].v=read();
sort(a+1,a+1+m,cmp);
for(int i=1,j=1;i<=n;i++){
update(1,1,n,1,i,-cost[i]);
update(1,1,n,i,i,ans);
while(a[j].r==i&&j<=m) update(1,1,n,1,a[j].l,a[j].v),j++;
ans=max(ans,tr[1]); //此时的 tr[1] 就是所有 dp 的最大值
}
cout<<ans;
return 0;
}

浙公网安备 33010602011771号