动态规划的决策单调性及其分治解法

有同学催更,于是更新了(

决策单调性优化的实现有单调队列,斜率优化,但是我一想起才学单调队列优化的时候一调调一天的不好回忆就不是很想碰它(( 斜率优化暂时没深入学

但是这两个东西可以做到比分治优化更优的时间复杂度。

但是 可爱的xde 师者言还有基于分治的单调性优化,我看了看发现这也太深刻了,于是科研了一下午和一晚上。


考虑一类经典的 DP 问题:

你可以把给定的序列划分成任意段,每一段可以根据一些奇怪的式子算出贡献来,接下来问你所有分法的贡献之和的最大值。

平凡的设 \(dp_i\) 表示考虑了前 \(i\) 个元素,最后以 \(i\) 结尾(因为你不能舍弃一些元素)的最大贡献。接下来转移就是 \(dp_i = \max_{j=1}^{i-1} dp_j+w(j+1,i)\),其中 \(w(j+1,i)\) 表示段 \([j+1,i]\) 的贡献。最后的答案在 \(dp_n\) 处取到。

这个方程十分容易理解,但是这样的转移是 \(O(n^2)\) 的,你会发现某些题是过不了的。

此时我们有这种类型 DP 比较套路的优化方法:猜决策单调性。

决策单调性

考虑上面的转移:

\[dp_i = \max_{j=1}^{i-1} dp_j+w(j+1,i) \]

\(dp_i\) 的最优值一定是从 \([1,i-1]\) 中间的某个 \(j\) 转移过来的,我们称这样的 \(j\)\(i\) 的决策点,令其为 \(op_i\),即 \(j=op_i\)。若决策点有多个,取最左边的那一个。

决策单调性是这样的性质:\(\forall i<j\),都有 \(op_i<op_j\)

有了这个性质,我们就可以通过某些技巧极大的缩小转移区间,以此达到优化时间复杂度的目的。

判断决策单调性

其实考场上可以直接猜,无需严格证明的,因此这个文章里对于证明的部分不会细讲(其实是不会(

存在决策单调性的一个充分不必要条件是上面的贡献函数满足四边形不等式。

四边形不等式的定义是:对于函数 \(w(i,j)\),若对于所有在其定义域内的四元数组 \(a\le b\le c\le d\)\(w\) 函数均满足:

\[w(a,c)+w(b,d) \le w(a,d)+w(b,c) \]

则称函数 \(w\) 满足四边形不等式。可以简记为 ”相交小于包含“

若等号反向,也就是:

\[w(a,c)+w(b,d) \ge w(a,d)+w(b,c) \]

则称函数 \(w\) 满足反向四边形不等式。可以简记为 ”相交大 于包含“。

正向的四边形不等式可以用于优化取 \(\min\) 的 DP 转移,反向的可以优化取 \(\max\) 的。

若函数 \(w\) 满足任意一种四边形不等式,则形如上文的 DP 转移方程存在决策单调性。

你问我为什么?我不道啊(( 证明可以看 OI Wiki 的相关页面

我们还有另一个更弱的条件证明四边形不等式:对于函数 \(w(i,j)\),若对于所有在其定义域内的四元数组 \(a\le b\)\(w\) 函数均满足:

\[w(a,c+1)+w(a+1,c) \le w(a,c)+w(a+1,c+1) \]

则函数 \(w\) 满足四边形不等式。

决策单调性优化的分治实现

若求解函数 \(w\) 的最优时间复杂度为 \(O(1)\),则我们可通过决策单调性优化将转移过程优化到 \(O(n \times \mathrm{polylog}(n))\)

考虑这样的分治过程:我们定义一个函数 \(\operatorname{solve}(l,r,L,R)\),表示我们将要分治处理区间 \([l,r]\),该区间的所有决策点都可以确定在 \([L,R]\) 内。

下一步,我们找到区间中点 \(mid\),在 \([L,R]\) 中暴力找到 \(mid\) 的决策点,设其为 \(t\)。在此过程中,更新 \(dp_{mid}\) 的值。

由于我们上文提到的决策单调性,此时区间 \([l,mid-1]\) 的最优决策点一定在 \([L,t]\) 中,区间 \([mid+1,r]\) 的最优决策点一定在 \([t,R]\) 内部。

于是接下来我们递归求解 \(\operatorname{solve}(l,mid-1,L,t)\)\(\operatorname{solve}(mid+1,r,t,R)\)。问题成功变成了子问题。

初始时,我们调用 \(\operatorname{solve}(1,n,1,n)\)

通过这样的分治,我们可以在 \(O(n\log n)\) 的时间复杂度下完成对 \(dp\) 的一轮更新。

带区间个数限制的 DP

我们把问题稍作加强:限制你最多(或者必须正好)把原数组分成 \(k\) 个段。

还是先考虑朴素的 DP,我们给 DP 数组升维:\(dp_{i,k}\) 表示考虑了前 \(i\) 个数,强制在 \(i\) 处结尾并把 \([1,i]\) 分成了恰好 \(k\) 段的最优答案。

接下来我们有两种转移方法,一是先枚举 \(i\) 及其转移点 \(j\),对于所有转移点,枚举 \(k\) 并进行更新;二是先枚举 \(k\),在每一层中枚举 枚举 \(i\) 及其转移点 \(j\) 并进行更新。不论何种更新顺序,时间复杂度都是 \(O(n^2k)\) 的。

考虑优化,由于我们实现的 \(\operatorname{solve}\) 函数一次就会更新整个 dp 数组,于是我们采用第二种方法更新,也就是调用 \(k\)\(\operatorname{solve}\)

但是这里要注意,每次调用的 \(l,r,L,R\) 都有所不同,对于第 \(k\) 次调用,应为 \(\operatorname{solve}(k,n,k-1,n)\)。最根本的原因是尝试把前 \(k-1\) 个数分成 \(k\) 份是不可能的,此时不能更新第 \(k-1\) 位,同理也不能从第 \(k-2\) 位置转移过来。

递归完成每一层后,取 \(dp_n\) 处的答案,对于每一层取最大值作为最终的答案。

上述算法的时间复杂度为 \(O(nk\log n)\)

这里来个例题:

P4360 [CEOI 2004] 锯木厂选址

翻了翻各路题解好像都没有分治版本的代码?于是这里写了一份。

这个题即为上述问题的 \(k=3\) 版本。

code

Show me the code
#define psb push_back
#define mkp make_pair
#define ls p<<1
#define rs (p<<1)+1
#define rep(i,a,b) for( int i=(a); i<=(b); ++i)
#define per(i,a,b) for( int i=(a); i>=(b); --i)
#define rd read()
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll read(){
  ll x=0,f=1;
  char c=getchar();
  while(c>'9'||c<'0'){if(c=='-') f=-1;c=getchar();}
  while(c>='0'&&c<='9'){x=(x<<3)+(x<<1)+(c^48);c=getchar();}
  return x*f;
}
const int N=30000;
struct tree{
  int dis;int w;
}t[N];
int S[N],D[N],W[N];
int cost(int l,int r){
  return S[l]-S[r+1]-(W[l]-W[r+1])*D[r];
}
ll dp[N],lst[N];
void solve(int l,int r,int L,int R){
  int mid=(l+r)>>1,trans=0;
  for(int i=L;i<=min(R,mid-1);i++){
    if(dp[mid]>lst[i]+cost(i+1,mid)){
      dp[mid]=lst[i]+cost(i+1,mid);
      trans=i;
    }
  }
  if(l==r)return ;
  solve(l,mid,L,trans);
  solve(mid+1,r,trans,R);
  return ;
}
int main(){
  
  int n;cin>>n;
  for(int i=1;i<=n;i++){
    cin>>t[i].w>>t[i].dis;
  }
  for(int i=n;i>=1;i--){
    D[i]=D[i+1]+t[i].dis;
    W[i]=W[i+1]+t[i].w;
    S[i]=S[i+1]+D[i]*t[i].w;
  }
  for(int j=1;j<=n;j++)lst[j]=dp[j],dp[j]=INT_MAX;
  for(int i=1;i<=n;i++){lst[i]=cost(1,i);}
  solve(2,n,1,n);
  ll ans=INT_MAX;
  for(int i=2;i<=n;i++){
    ans=min(ans,dp[i]+cost(i+1,n+1));
  }
  cout<<ans;

  return 0;
}

不限制区间个数的 DP

这种分治使用了 CDQ 分治的思想。因为此时我们必须顺序更新 DP 数组,不能分层更新了。

我们在原来的分治基础上再套一层分治,定义函数 \(\operatorname{cdq}(l,r)\) 表示要处理并正确更新所有 \(dp_l \sim dp_r\)

接下来,找到区间中点 \(mid\),调用 \(\operatorname{cdq}(l,mid)\)

现在,\(dp_l \sim dp_{mid}\) 均已被正确更新,接下来调用上文的 \(\operatorname{solve}(mid+1,r,l,mid)\)

这次调用的目的是用 \([l,mid]\) 内的所有 dp 更新 \([mid,r]\) 之内的所有 dp。

你可能问这也不对啊,你怎么保证 \(dp_r\) 的决策点就在 \([l,mid]\) 之间?

事实上,这里牵扯到一个分治过程。首先考虑 \([1,l-1]\) 的区间,你会发现由于我们先向左边递归,因此递归到 \([l,r]\) 的时候,上面一定有一个大区间的左边给右边更新一次了,因此左边的情况就不用再考虑了。

接下来考虑 \([mid+1,r]\)。在调用完 \(\operatorname{solve}(mid+1,r,l,mid)\) 之后,你会发现至少 \(dp_{mid+1}\) 的值一定是正确的。

而调用完 \(\operatorname{solve}(mid+1,r,l,mid)\) 之后,我们还要紧接着调用一次 \(\operatorname{cdq}(mid+1,r)\)

当这次 \(\operatorname{cdq}\) 递归至区间 \([mid+1,mid+2]\) 并完成时,\(mid+2\) 的值也会被正确更新(\([1,mid]\) 的情况不用再考虑,\(mid+1\) 的值是对的)。

此时返回到区间 \([mid+1,mid+4]\)\(mid+3\) 的值也会被更新,接下来是 \(mid+4,mid+5,mid+6 \cdots\)

这样下来,所有 dp 值实际上都是被按顺序且正确更新的。

太深刻了兄弟,上述算法的时间复杂度为 \(O(n \log^2 n)\)

也给一道例题:

P3628 [APIO2010] 特别行动队

是板子啦。但是可以被单调队列 \(O(n)\) 包菜(

但是它好写啊而且能过就行(

code

Show me the code
#define psb push_back
#define mkp make_pair
#define ls p<<1
#define rs (p<<1)+1
#define rep(i,a,b) for( int i=(a); i<=(b); ++i)
#define per(i,a,b) for( int i=(a); i>=(b); --i)
#define rd read()
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll read(){
  ll x=0,f=1;
  char c=getchar();
  while(c>'9'||c<'0'){if(c=='-') f=-1;c=getchar();}
  while(c>='0'&&c<='9'){x=(x<<3)+(x<<1)+(c^48);c=getchar();}
  return x*f;
}
const int N=1e6+100;
#define int long long
int a,b,c;
int x[N],fx[N];
int dp[N];
int w(int l,int r){
  int ci=fx[r]-fx[l-1];
  return ci*ci*a+ci*b+c;
}
void solve(int l,int r,int L,int R){
  if(l>r)return ;
  int mid=l+r>>1; 
  int t=0;
  int tc=LONG_LONG_MIN;
  for(int i=L;i<=R&&i<mid;i++){
    if(dp[i]+w(i+1,mid)>tc){
      tc=dp[i]+w(i+1,mid);
      t=i;
    }
  }
  dp[mid]=max(dp[mid],tc);
  solve(l,mid-1,L,t);
  solve(mid+1,r,t,R);
  return ;
}
void cdq(int l,int r){
  if(l==r)return ;
  int mid=l+r>>1;
  cdq(l,mid);
  solve(mid+1,r,l,mid);
  cdq(mid+1,r);
  return ;
}
signed main(){
  
  int n;cin>>n;
  cin>>a>>b>>c;
  for(int i=1;i<=n;i++){
    cin>>x[i];
    fx[i]=fx[i-1]+x[i];
  }
  memset(dp,-0x3f,sizeof dp);
  dp[0]=0;
  cdq(0,n);
  cout<<dp[n];

  return 0;
}

更多例题可能会开到好题选讲里面(

end here

posted @ 2025-07-22 21:09  hm2ns  阅读(88)  评论(0)    收藏  举报