P5574 [CmdOI2019] 任务分配问题

题目描述

经典的分 \(k\) 段问题,要求求出分 \(k\) 段后使每段顺序对数量之和最小,求这个最小的值。

思路

首先,我们很好得出这种分段问题的状态转移方程即 $$dp_{i,j}=\min{dp_{k,j-1}+w(k+1,i)}$$ 其中 \(dp_{i,j}\) 表示选到前 \(i\) 个数,分了 \(j\) 段的最小费用,我们可以用 \(O(n^2k)\) 的时间复杂度来实现,显然超时,得分20pts

接着考虑优化,不难发现,\(w(i+1,j)+w(i,j+1) \ge w(i,j)+w(i+1,j+1)\),所以,该式子满足四边形不等式,即可使用决策单调性来优化该状态转移方程。

P4767邮局这道题,与该题状态转移方程相同,所以一上手想到使用四边形不等式中 \(opt(i,j-1) \ge opt(i,j) \ge opt(i+1,j)\) 的性质优化该题优化该题,然而,这种算法时间复杂度为 \(O(n(n+m))\),并不能通过该题,结果超时,得分40pts,邮局一题中 \(n\)\(m\) 上界相同,而该题 \(m\) 远小于 \(n\) 我们需要使用分治优化将时间复杂度降到 \(O(k \times n \log n)\) 级别才能通过。

我们解决了DP阶段的问题,接着需要解决的就是预处理 \(w\) 数组的问题了,由于我们无法接受 \(O(n^2)\) 的时间复杂度,所以我们要对其进行优化,因为数组大小的限制,预处理出 \(w\) 数组这条思路已经行不通了,我们考虑在DP过程中求一段的花费,注意到,在分治求解的过程中,当决策点每向右移动一位,我们的费用由 \(w(i,j)\) 变为 \(w(i,j+1)\) 过程中,我们增加的顺序对费用即为 \(i\)\(j\) 中小于 \(a_{j+1}\) 的数的个数,而这个我们很容易实现 \(log\) 级别的时间复杂度,所以,记录 \(tl\)\(tr\) 分别表示上个状态转以后已经处理完的左右端点,如果两状态位于分治中同一区间,则每次转移需要 \(O(\log n)\) 的时间复杂度,如果改变了区间,需要先跳到所求区间,设两区间分别为 \([L_i,R_i]\)\([L_j,R_j]\) 所需要跳的步数即为 \(R_j-L_i\),放在分治中即可粗略计算为每层跳 \(n\) 次,每次跳跃转移点同样需要 \(O(\log n)\),综上所述,时间复杂度为 \(O(k \times n \log^2 n)\) 可以通过该题。

求顺序对时若使用线段树实现,时间超限,得分60pts,使用常数更小的树状数组实现,成功通过该题。

另外,观察状态转移方程,每次转移只和 \(j-1\) 有关,所以我们可以压缩掉一维,只记录当前与上一行状态。

代码

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int M=25010;
int n,m;
int a[M],w[M];
int dp[M][2],tl=1,tr=0,sum=0;
int lowbit(int x) 
{
	return x&-x;
}
void add(int x,int w)
{
	for(;x<=n;x+=lowbit(x))
		a[x]+=w;
}
int query(int x)
{
	int ans=0;
	for(;x;x-=lowbit(x))
		ans+=a[x];
	return ans;
}
void ask(int le,int re)
{
    while(tr<re)
	{
		++tr;
		sum+=query(w[tr]);
		add(w[tr],1);
	}
    while(tl>le)
	{
		--tl;
		sum+=tr-tl-query(w[tl]);
		add(w[tl],1);
	}
    while(tr>re)
	{
		add(w[tr],-1);
		sum-=query(w[tr]);
		tr--;
	}
    while(tl<le)
	{
		add(w[tl],-1);
		sum-=tr-tl-query(w[tl]);
		tl++;
	}
}
void solve(int le,int re,int lt,int rt)
{
	int mid=(le+re)>>1;
	int k=mid;
	for(int i=lt;i<=min(rt,mid-1);i++)
	{
		ask(i+1,mid);
        if(dp[i][1]+sum<=dp[mid][0])
		{
			dp[mid][0]=dp[i][1]+sum;
			k=i;
		}
	}
	if(mid-1>=le) solve(le,mid-1,lt,k);
	if(mid+1<=re) solve(mid+1,re,k,rt);
}
int main()
{
	memset(dp,0x3f,sizeof(dp));
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
	{
		scanf("%d",&w[i]);
	}
	dp[0][0]=dp[0][1]=0;
	for(int i=1;i<=m;i++)
	{
		solve(1,n,0,n-1);
		for(int j=1;j<=n;j++)
		{
			dp[j][1]=dp[j][0];
		}
	}
	printf("%d\n",dp[n][0]);
    return 0;
}
posted @ 2025-08-20 23:29  Naoxiaoyu  阅读(27)  评论(0)    收藏  举报