题解:CF2183H

题意

\[f(b) = m \times \sum^m_{i=1}b_i \]

给定一个序列 \(a\) ,要求将其划分为 \(k\) 个不交的非空子序列 \(s_1,s_2,...,s_k\),最小化 \(\sum f(s_i)\)

题解

STEP 1

很有意思的一道题,首先发现子序列比较难做,考虑来寻找一些可爱的性质。

首先可以想到将序列 \(a\) 从大到小排序,注意到我们划分出来的子序列一定是新序列的子串。

我们把划分出来的子串的长度分别设为 \(m_1,m_2,...,m_k\),还可以注意到 \(m_i \ge m_{i+1} , i \in [1,k)\)

也就是存在一种最优划分,使得所有子序列在序列 \(a\) 降序排序后成为连续子段,且子段长度非递增。

记原序列为 \(a\),将其降序排序得到 \(a'\)。设任意一种划分得到 \(k\) 个子序列 \(s_1, s_2, \dots, s_k\),每个 \(s_i\) 的元素集合为 \(S_i\)

定义 \(m_i = |S_i|\)\(sum_i = \sum_{x \in S_i} x\)

则总代价为 \(\sum_{i=1}^k m_i \cdot sum_i\)

由于 \(f\) 仅与子序列的长度和元素和有关,与顺序无关,我们可以将每个 \(S_i\) 中的元素按任意顺序排列。现将 \(S_i\) 按照其最大元素从大到小排序,不妨设 \(\max(S_1) \ge \max(S_2) \ge \dots \ge \max(S_k)\)

假设存在 \(i < j\) 使得 \(m_i < m_j\),即最大元素较大的子序列长度反而较小。考虑交换 \(S_i\) 中的最小元素 \(x\)\(S_j\) 中的最大元素 \(y\)。由于 \(\max(S_i) \ge \max(S_j)\)\(x \le y\)(因为 \(x\)\(S_i\) 中最小,\(y\)\(S_j\) 中最大),交换后 \(S_i\) 变为 \((S_i \setminus \{x\}) \cup \{y\}\)\(S_j\) 变为 \((S_j \setminus \{y\}) \cup \{x\}\)。长度不变,新的元素和分别为 \(sum_i - x + y\)\(sum_j - y + x\)。代价变化量为:

\[\Delta = m_i(y-x) + m_j(x-y) = (m_i - m_j)(y-x) \]

由于 \(m_i < m_j\)\(y > x\),故 \(\Delta < 0\),即总代价减少。

因此,通过有限次这样的交换,我们可以使长度 \(m_i\) 与最大元素同序(即 \(m_i\) 非递增),且每次交换不增加代价。

进一步,由于 \(a'\) 是降序的,我们可以将每个 \(S_i\) 调整为 \(a'\) 中的连续子段(因为交换不同子序列的元素不会破坏长度非递增性,且最终每个子序列中的元素在 \(a'\) 中连续)。

具体地,将 \(a'\) 从前向后扫描,依次分配长度为 \(m_1, m_2, \dots, m_k\) 的连续子段。

由于 \(m_i\) 非递增,且 \(a'\) 降序,这样分配得到的子序列满足最大元素较大的子序列长度较小。

因此,存在最优划分满足上述结构。

有了这两个结论,我们就可以有一个 \(O(n^2k)\) 的暴力 DP。

\(f_{i,j}\) 表示前 \(i\) 个数划分为恰好 \(j\) 段的最小权值和,转移为:

\[f_{i,j} = \min\limits_{k=1}^{i-1} f_{k,j-1} + (s_i-s_k) \times (i-k) \]

其中 \(s\) 表示序列 \(a\) 的前缀和数组。

STEP2

到了这里你会发现这道题的本质其实是一个限制个数的区间划分问题。

发现 \(f_{n,k}\) 关于 \(k\) 是一个凸函数,可以使用 WQS 二分消去划分次数限制,证明如下:

考虑 \(k-1\) 段和 \(k+1\) 段的最优划分,通过合并或分裂某些段,可以得到两个 \(k\) 段的划分,利用均值不等式即可证明凸性。

具体地,设 \(A\)\(k-1\) 段的最优划分,\(B\)\(k+1\) 段的最优划分。将 \(A\) 中的某一段分裂成两段,得到 \(k\) 段划分 \(A'\);将 \(B\) 中的某相邻两段合并成一段,得到 \(k\) 段划分 \(B'\)。则有:

\[f(k-1) + f(k+1) \ge \text{cost}(A') + \text{cost}(B') \ge 2F(k). \]

其中第一个不等式是因为 \(A\)\(B\) 是最优的,所以分裂或合并后的代价不小于最优 \(k\) 段代价;第二个不等式是因为 \(\text{cost}(A')\)\(\text{cost}(B')\) 均不小于 \(f(k)\)。故 \(f(k)\) 下凸。

这样我们就有了 \(O(n^2log C)\) 的进阶版暴力,还得再优化一下每次 check 的内层 DP。

STEP3

注意到一个新的性质,负数至多划分为一段,这也就是说 DP 过程中负数段不可能有决策。

若将某个负数与前面的正数段合并,设正数段原长为 \(m\),和为 \(s>0\),负数为 \(neg<0\)

合并后该段代价变化为 \((m+1)(s+neg) - m s = s + m \cdot neg + neg\)

而该负数若在最后一段中,其贡献已计入。

分析总代价变化可知,当 \(|neg|\) 较小时合并可能导致代价增加。

通过计算可证,将负数保留在最后一段不劣于将其合并到正数段。

这样的话,所有的负数决策点就可以删掉了(注意这里要特判正数数量小于 \(k\) 的情况)。

此时我们发现内层 DP 满足决策单调性,WQS 二分后,内层 DP 形如:

\[dp[i] = \min_{0 \le j < i} \left\{ dp[j] + (s_i - s_j) \cdot (i - j) \right\} + \text{penalty}, \]

\(w(j, i) = (s_i - s_j) \cdot (i - j)\),需证明 \(w\) 满足四边形不等式:

\[\forall a < b < c < d,\quad w(a, c) + w(b, d) \le w(a, d) + w(b, c). \]

代入展开:

\[(s_c - s_a)(c-a) + (s_d - s_b)(d-b) \le (s_d - s_a)(d-a) + (s_c - s_b)(c-b). \]

整理得:

\[(s_c - s_d)(a-b) + (s_a - s_b)(c-d) \ge 0. \]

由于序列 \(a\) 降序排序,故前缀和 \(s\) 的差分(即 \(a_i\))递减,从而 \(s\) 是凹函数(二阶差分非正)。

因此对于 \(a<b<c<d\),有:

\[s_b - s_a \ge s_c - s_b \ge s_d - s_c. \]

注意到 \(s_a - s_b = -(s_b - s_a)\)\(s_c - s_d = -(s_d - s_c)\),故 \(s_a - s_b \le s_b - s_c \le s_c - s_d\)(均为非正数,绝对值递增)。

又因为 \(a-b<0\)\(c-d<0\),所以 \((s_c - s_d)(a-b) > 0\)\((s_a - s_b)(c-d) > 0\),不等式成立。

从而四边形不等式成立,决策单调性得证。

实现就是外层 WQS 二分,内层二分队列 。

参考代码

const int N = 2e5+10;

#define int i128

inline void write(int x)
{
    if(x < 0)
    {
		putchar('-');
		write(-x);
        return;
	}
	if(x < 10)
    {
		putchar(x+'0');return;
	}
	write(x/10);
    putchar(x % 10 + '0');
}

inline void writeln(int x)
{
    write(x);
    putchar('\n');
}

int a[N],s[N],f[N],g[N],inf,pr[N],pos[N],n,k,i,l,r,mid,c;

void init() {inf = 1;rep(i,1,30) inf = inf * 10;}

int cal(int l,int r) {return (s[r] - s[l]) * (r - l);}

bool cmp(int p1,int p2,int pr) 
{
    if(f[p1] + cal(p1,pr) < f[p2] + cal(p2,pr)) return 1;
    if(f[p1] + cal(p1,pr) == f[p2] + cal(p2,pr) && g[p1] > g[p2]) return 1;
    return 0;
}

void ck(int n,int c,int w)
{
	int cl,cr,l,r,mid,p = 0;
	pr[0] = c; pos[0] = f[0] = g[0] = cl = 0; cr = 1;

	rep(i,1,c)
	{
		while(pr[cl] < i) cl++;
		f[i] = f[pos[cl]] + cal(pos[cl],i) + w; g[i] = g[pos[cl]]+1;
		if(!cmp(i,pos[cr-1],c)) continue;

		while(1)
		{
			l = i, r = c;
			while(r - l > 1)
			{
				mid = l + r >> 1;
				if(cmp(i,pos[cr-1],mid)) r = mid;
				else l = mid;
			}
			pr[cr-1] = l;
			if(cr - cl >= 2 && pr[cr-1] <= pr[cr-2]) cr--;
			else break;
		}
		pr[cr] = c; pos[cr++] = i;
	}
	rep(i,1,c) if(cmp(i,p,n)) p = i;
	f[n] = f[p] + cal(p,n) + w;
    g[n] = g[p] + 1;

    // std::cout << f[n] << ' ' << g[n] << '\n';
}

int wqs()
{
    l = -inf, r = inf;

    while(r - l > 1)
    {
        mid = l + (r - l) / 2;
        ck(n,c,mid);
        if(g[n] >= k) l = mid;
        else r = mid;
    }
    ck(n,c,l);
    
    return f[n] - k * l;
}

void solve()
{   
    init();
    n = read() , k = read(); rep(i,1,n) a[i] = read();

    std::sort(a+1,a+n+1,std::greater<int>());

    rep(i,1,n) s[i] = s[i-1] + a[i];
    c = 0; rep(i,1,n) c += a[i] > 0;

    ckmin(c,n-1);

    if(k > c)
    {
        int ans = 0;
        rep(i,1,k-1) ans += a[i];
        ans += (s[n] - s[k-1]) * (n - k + 1);
        writeln(ans);
        return ;
    }

    writeln(wqs());
}
posted @ 2026-01-09 20:36  Zheng_iii  阅读(45)  评论(0)    收藏  举报