Loading

[集训队互测 2025] 火花 做题记录

link

有点牛逼啊这题。

考虑你取路径和取含根连通块,这两部分的没有很明显的贪心关系,只能将它们同时 DP 决策。

对于 \(t = 0\) 的 subtask 有个很明显的做法,就是在 dfs 序上 dp,多重背包部分用单调队列优化。这告诉我们正解一定是在 dfs 序上做 dp 的,或者用一些更高级的做法。

如果要 dp,现在的问题在于如何维护一个点子树内选了多少条路径,因为这会影响这个点物品数量。

这题最厉害的地方在于这里:感受 \(t < c_u\) 这个条件的充实之处,在 dfs 序上扫描时记录 \(x\) 表示后续选择的路径数量。当到达一个点 \(u\) 时,先加入 \(c_u - x\) 个物品,这 \(c_u - x\) 个是一定保留的。在 \(u\) 子树的结束位置处,再加入 \(x\) 个物品,注意 \(x\) 是实时变化的。

会发现这样就满足条件了,这和差分的思有很大的相似之处,在此基础上进一步升华。

\(u\) 子树内选若干条路径是一个凸函数,所以可以决策单调性优化,时间复杂度 \(\mathcal O(nkt\log t)\)


  • 感受题目,感受条件,从千万种可能中抓住唯一的答案。

  • 这里的差分思想很厉害,可以多借鉴一下。


点击查看代码
#include <bits/stdc++.h>
#define ll int
#define LL long long
#define ull unsigned
#define uLL unsigned LL
#define fi first
#define se second
#define mkp make_pair
#define pir pair<ll, ll>
#define pb push_back
#define i128 __int128
using namespace std;
template <class T>
const inline void rd(T &x) {
    char ch;
    bool neg = 0;
    while (!isdigit(ch = getchar()))
        if (ch == '-')
            neg = 1;
    x = ch - '0';
    while (isdigit(ch = getchar())) x = (x << 1) + (x << 3) + ch - '0';
    if (neg)
        x = -x;
}
const ll maxn = 2e4 + 10, mod = 998244353, M = 1e6;
const LL inf = 1e18 + 5;
ll power(ll a, ll b = mod - 2, ll p = mod) {
    ll s = 1;
    while (b) {
        if (b & 1)
            s = 1ll * s * a % p;
        a = 1ll * a * a % p, b >>= 1;
    }
    return s;
}
template <class T, class _T>
const inline ll pls(const T x, const _T y) { return x + y >= mod ? x + y - mod : x + y; }
template <class T, class _T>
const inline ll mus(const T x, const _T y) { return x < y ? x + mod - y : x - y; }
template <class T, class _T>
const inline void add(T &x, const _T y) { x = x + y >= mod ? x + y - mod : x + y; }
template <class T, class _T>
const inline void sub(T &x, const _T y) { x = x < y ? x + mod - y : x - y; }
template <class T, class _T>
const inline void chkmax(T &x, const _T y) { x = x < y ? y : x; }
template <class T, class _T>
const inline void chkmin(T &x, const _T y) { x = x < y ? x : y; }

ll n, k, t, c[maxn], w[maxn], idfn[maxn], dfn[maxn], out[maxn], ti, siz[maxn];
vector <vector <LL> > f[maxn];
vector <LL> d[maxn], g[maxn];
vector <ll> to[maxn];
ll q[maxn], l, r;
LL a[maxn], b[maxn], e[maxn], h[maxn], sum[maxn];

void dfs(ll u) {
    idfn[dfn[u] = ++ti] = u, sum[u] += w[u];
    d[u].resize(t + 1), siz[u] = 1;
    for(ll i = 1; i <= t; i++) d[u][i] = -inf;
    if(t) d[u][1] = 0;
    for(ll v: to[u]) {
        sum[v] = sum[u], dfs(v), siz[u] += siz[v];
        ll p = 0, q = 0;
        for(ll i = 1; i <= t; i++) h[i] = -inf;
        while(p + q < t) {
            if(d[u][p + 1] >= d[v][q + 1])
                ++p, h[p + q] = d[u][p];
            else ++q, h[p + q] = d[v][q];
        }
        for(ll i = 1; i <= t; i++) d[u][i] = h[i];
    }
    for(ll i = 1; i <= t; i++)
        if(d[u][i] >= 0) d[u][i] += w[u];
    idfn[out[u] = ++ti] = u;
}

void solve(ll l, ll r, ll jl, ll jr) {
    if(l > r) return;
    ll mid = l + r >> 1, j = 0; e[mid] = -inf;
    for(ll i = max(jl, mid); i <= jr; i++) {
        LL tmp = a[i] + b[i - mid];
        if(tmp > e[mid]) e[mid] = tmp, j = i;
    }
    solve(l, mid - 1, jl, j);
    solve(mid + 1, r, j, jr);
}

int main() {
    rd(n), rd(k), rd(t);
    for(ll i = 1; i <= n; i++) rd(c[i]), rd(w[i]);
    for(ll i = 2, x; i <= n; i++) rd(x), to[x].pb(i);
    dfs(1);
    for(ll i = 1; i <= n; i++)
        for(ll j = 1; j <= min(t, siz[i]); j++)
            d[i][j] += d[i][j - 1];
    for(ll i = 1; i <= 2 * n + 1; i++) {
        f[i].resize(k + 1);
        for(ll j = 0; j <= k; j++) {
            f[i][j].resize(t + 1);
            for(ll x = 0; x <= t; x++) f[i][j][x] = -inf;
        }
    }
    f[1][0][t] = 0;
    for(ll i = 0; i <= k; i++) g[i].resize(t + 1);
    for(ll i = 1; i <= 2 * n; i++) {
        ll u = idfn[i];
        if(dfn[u] == i) {
            for(ll x = 0; x <= t; x++) {
                ll p = c[u] - x; l = 1, r = 0;
                for(ll j = 0; j <= k; j++) {
                    while(l <= r && j - q[l] > p) ++l;
                    if(j) {
                        chkmax(f[i + 1][j][x], f[i][q[l]][x] + 1ll * (j - q[l]) * w[u]);
                        if(x) chkmax(f[i + 1][j][x - 1], f[i][q[l]][x]
                         + 1ll * (j - q[l]) * w[u] + sum[u]);
                    }
                    while(l <= r && f[i][q[r]][x] - 1ll * q[r]
                     * w[u] <= f[i][j][x] - 1ll * j * w[u]) --r;
                    q[++r] = j;
                    g[j][x] = f[i][q[l]][x] + 1ll * (j - q[l]) * w[u];
                }
            }
            for(ll j = 0; j <= k; j++) {
                for(ll x = 0; x <= t; x++)
                    a[x] = g[j][x], b[x] = d[u][x] + (sum[u] - w[u]) * x;
                ll p = 0;
                while(p <= t && a[p] < 0) ++p;
                solve(max(0, p - siz[u]), t, p, t);
                for(ll x = 0; x < p - siz[u]; x++) e[x] = -inf;
                for(ll x = 0; x <= t; x++) g[j][x] = e[x];
            }
            ll z = out[u] + 1;
            for(ll x = 0; x <= t; x++) {
                l = 1, r = 0;
                for(ll j = 0; j <= k; j++) {
                    while(l <= r && g[q[r]][x] - 1ll * q[r]
                     * w[u] <= g[j][x] - 1ll * j * w[u]) --r;
                    q[++r] = j;
                    while(l <= r && j - q[l] > x) ++l;
                    chkmax(f[z][j][x], g[q[l]][x] + 1ll * (j - q[l]) * w[u]);
                }
            }
        } else {
            for(ll x = 0; x <= t; x++) {
                l = 1, r = 0;
                for(ll j = 0; j <= k; j++) {
                    while(l <= r && f[i][q[r]][x] - 1ll * q[r]
                     * w[u] <= f[i][j][x] - 1ll * j * w[u]) --r;
                    q[++r] = j;
                    while(l <= r && j - q[l] > x) ++l;
                    chkmax(f[i + 1][j][x], f[i][q[l]][x] + 1ll * (j - q[l]) * w[u]);
                }
            }
        }
    }
    LL ans = 0;
    for(ll i = 0; i <= k; i++) chkmax(ans, f[2 * n + 1][i][0]);
    printf("%lld\n", ans);
	return 0;
}
posted @ 2025-11-12 19:33  Sktn0089  阅读(47)  评论(0)    收藏  举报