Loading

Luogu11303 做题记录

dp 神题。 link

常规 \(\mathcal O(n^2)\) 是简单的,瓶颈在于需要同时维护当前的 \(l, r\) 值。

发现 dp 是不可避免的,我们是否可以分离 \(l, r\) 的计算贡献为两部分,以此获得优化的机会。

比如我们的到达序列为 \(k, \{k - 1, k - 2, \dots, l_1\}, \{k + 1, k + 2, \dots, r_1\}, \{l_1 - 1, l_1 - 2, \dots, l_2\}, \{r_1 + 1, r_1 + 2, \dots, r_2\}\)

  • 对于 \(i \in [l_1, k]\) 来说,代价就是 \(x_k - x_i\)

  • 对于 \(i \in [k + 1, r_1]\) 来说,代价就是 \(x_i - x_k + 2(x_k - x_{l_1})\)

  • 对于 \(i \in [l_2, l_1 - 1]\) 来说,代价就是 \(x_k - x_i + 2(x_{r_1} - x_{l_1})\)

  • 对于 \(i \in [r_1 + 1, r_2]\) 来说,代价就是 \(x_i - x_k + 2(x_k - x_{l_1}) + 2(x_{r_1} - x_{l_2})\)

代价 \(\sum\limits_{i = 1} ^ n |x_i - x_k|\) 是常规的,所以我们只关心额外代价。

\([l, r]\)\(l\) 处走到 \([l, r']\)\(r'\) 处,对于点 \(1 \sim l - 1\) 都产生了 \(2(x_r - x_l)\) 的额外代价。

同理,从 \(r\) 走到 \(l'\) 处,对点 \(r + 1 \sim n\) 都产生了 \(2(x_r - x_l)\) 的额外代价。

假定一开始往左走,那么我们相当于找两个序列 \(L, R\),满足 \(|L| - 1 \le |R| \le |L|\),其额外代价总和为:

\[2(n - k)(x_k - x_{L_1}) + \sum_{i = 2} ^ {|L|} 2(n - R_{i - 1})(x_{R_{i - 1}} - x_{L_i}) + \sum_{i = 1} ^ {|R|} 2(L_i - 1)(x_{R_i} - x_{L_i}) \]

注意到如果 \(L\) 不单调递减或 \(R\) 不单调递增是不优的,所以可以去掉这个限制。

这样就完全把 \(l, r\) 的状态分离了出来。

\(f_l, g_r\) 分别表示两个序列分别拼到了左边的 \(l\),或者右边的 \(r\) 时的答案。

转移即是从 \(l\) 拼到一个新的 \(r\),或者从 \(r\) 拼到一个新的 \(l\)

\[f_l = \min_{r\ge k} \{ g_r + 2(n - r)(x_r - x_l) \} \]

\[g_r = \min_{l\le k} \{ f_l + 2(l - 1)(x_r - x_l) \} \]

容易发现 \(f_k\le f_{k - 1}\le \dots \le f_1\)\(g_k\le g_{k + 1}\le \dots \le g_n\),可以仿照 dijkstra 的方式从 \(k\) 从小到大逐步向 \(1\)\(n\)。剩下就是斜率优化了,时间复杂度 \(\mathcal O(n)\)


这题关键在于:

  • 发现不用 dp 做不了

  • 尝试将 \(l, r\) 分离

注意我们不需要严格分离 \(l, r\) 的贡献,只需要分离其状态为两部分。

所以面对一些题目,不需要严格按照某个套路去做,多去思考,多去尝试。

手摸是个不错的方法,如果能带入具体数值就更有帮助,因为这样有利于观察模型,刻画模型,提取问题的本质。

不要认为某些方法不可做,万一推到了这一步,就存在一种明显的优化方法了呢?所以多去实践,才能发现更多更广的事物。


点击查看代码
#include <bits/stdc++.h>
#define ll long long
#define ull unsigned ll
#define fi first
#define se second
#define mkp make_pair
#define pir pair <ll, ll>
#define pb push_back
#define i128 __int128
using namespace std;
char buf[1 << 22], *p1, *p2;
// #define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, (1 << 22) - 10, stdin), p1 == p2)? EOF : *p1++)
template <class T>
const inline void rd(T &x) {
    char ch; bool neg = 0;
    while(!isdigit(ch = getchar()))
        if(ch == '-') neg = 1;
    x = ch - '0';
    while(isdigit(ch = getchar()))
        x = (x << 1) + (x << 3) + ch - '0';
    if(neg) x = -x;
}
const ll maxn = 3e5 + 10, inf = 1e18, mod = 1e9 + 7;
ll power(ll a, ll b = mod - 2) {
	ll s = 1;
	while(b) {
		if(b & 1) s = 1ll * s * a %mod;
		a = 1ll * a * a %mod, b >>= 1;
	} return s;
}
template <class T, class _T>
const inline ll pls(const T x, const _T y) { return x + y >= mod? x + y - mod : x + y; }
template <class T, class _T>
const inline void add(T &x, const _T y) { x = x + y >= mod? x + y - mod : x + y; }
template <class T, class _T>
const inline void chkmax(T &x, const _T y) { x = x < y? y : x; }
template <class T, class _T>
const inline void chkmin(T &x, const _T y) { x = x < y? x : y; }

ll n, k, a[maxn], f[maxn], g[maxn], L, R;
ll q1[maxn], l1, r1, q2[maxn], l2, r2, h1[maxn], h2[maxn];

int main() {
    rd(n), rd(k); L = R = k;
    for(ll i = 2; i <= n; i++) rd(a[i]), a[i] += a[i - 1];
    q1[l1 = r1 = 1] = q2[l2 = r2 = 1] = k;
    h1[k] = 2 * a[k] * (n - k), h2[k] = -2 * a[k] * (k - 1);
    while(L > 1 || R < n) {
        ll x = inf, y = inf;
        if(L > 1) {
            while(l1 < r1 && -2 * a[L - 1] *
             (q1[l1 + 1] - q1[l1]) >= h1[q1[l1 + 1]] - h1[q1[l1]]) ++l1;
            ll c = q1[l1];
            x = g[c] + 2 * (a[c] - a[L - 1]) * (n - c);
        }
        if(R < n) {
            while(l2 < r2 && -2 * a[R + 1] *
             (q2[l2 + 1] - q2[l2]) >= h2[q2[l2 + 1]] - h2[q2[l2]]) ++l2;
            ll c = q2[l2];
            y = f[c] + 2 * (a[R + 1] - a[c]) * (c - 1);
        }
        if(x <= y) {
            f[--L] = x, h2[L] = f[L] - 2 * a[L] * (L - 1);
            while(l2 < r2 && (q2[r2] - L) * (h2[q2[r2]] - h2[q2[r2 - 1]])
             >= (q2[r2 - 1] - q2[r2]) * (h2[L] - h2[q2[r2]])) --r2;
            q2[++r2] = L;
        } else {
            g[++R] = y, h1[R] = g[R] + 2 * a[R] * (n - R);
            while(l1 < r1 && (R - q1[r1]) * (h1[q1[r1]] - h1[q1[r1 - 1]])
             >= (q1[r1] - q1[r1 - 1]) * (h1[R] - h1[q1[r1]])) --r1;
            q1[++r1] = R;
        }
    } ll sum = 0;
    for(ll i = 1; i <= n; i++) sum += abs(a[i] - a[k]);
    printf("%lld\n", sum + min(f[1], g[n]));
    return 0;
}
posted @ 2025-04-14 08:59  Sktn0089  阅读(27)  评论(0)    收藏  举报