题解:accoders::NOI 5519 虚树 / P12462 [Ynoi Easy Round 2018] 星野爱久爱海

2023 牛客 OI 赛前集训营-提高组(第一场)D

problem

有一棵 \(n\) 个节点,边权是正整数的树,和一个 \([1, n]\) 的排列 \(p\)。 有 \(Q\) 次询问,每次给出 \(l,r,k\),你需要回答在 \(p\) 的区间 \([l,r]\)\(k\) 个点,问这 \(k\) 个点构成的虚树的边权和最大和。

\(n\leq 200000,q\leq 10000,k\leq 100\)。强制在线。5s。

solution

考虑如果询问全局怎么做,那么就直接套用 题解 accoders::NOI 5511【漂亮轰炸(bomb)】 的做法,具体来说就是选出直径的一端为根,对树做带权的长链剖分,选出前 \(k-1\) 条长链。
那么对于区间的一个暴力是对区间中的点建虚树,跑上面的算法。

考虑优化,首先有一个结论是两个区间的答案是可以合并的,只需要取两个区间的答案中取到的长链的叶子和根,合并起来建虚树,再做上面的算法,因为其他的链一定取不到了。同时我们只保留前 \(k-1\) 条链。这样一次合并的复杂度可以认为是 \(O(k\log k)\)

所以我们可以使用线段树维护答案。这样复杂度大概是 \(O(nk\log k+qk\log k\log n)\) 比较爆炸。我们发现这个问题是可以 ST 表的,考虑 ST 表,但是 ST 表的预处理部分复杂度是 \(O(nk\log n\log k)\)。考虑的优化是将每 \(100\) 个点绑在一起进行分块,对 \(2000\) 个块建 ST 表,然后查询时暴力拿出散块,大力合并答案。\(O(\frac{nk\log k\log n}{k}+qk\log k)\) 已经可以通过本题。

code

// ubsan: undefined
// accoders
#include <cstdio>
#include <vector>
#include <cstring>
#include <cassert>
#include <algorithm>
#include <functional>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
typedef long long LL;
typedef long long i64;
int testop, n;
i64 lstans;
inline void decode(int &l, int &r, int &k, i64 lstans, int testop) {
    lstans %= 19260817;
    if (testop) {
        l ^= lstans;
        l = (l % n + n) % n + 1;
        r ^= lstans;
        r = (r % n + n) % n + 1;
        if (l > r)
            std ::swap(l, r);
        k ^= lstans;
        k = (k % std ::min(r - l + 1, 100)) + 1;
    }
}
template <int N, class T>
struct STable {
    T f[21][N + 10];
    int lg[N + 10], n;
    function<T(T, T)> func;
    STable(function<T(T, T)> func) : n(0), func(func) { lg[0] = -1; }
    int insert(const T &x) {
        f[0][++n] = x, lg[n] = lg[n >> 1] + 1;
        for (int j = 1; 1 << j <= n; j++) {
            int i = n - (1 << j) + 1;
            f[j][i] = func(f[j - 1][i], f[j - 1][i + (1 << (j - 1))]);
        }
        return n;
    }
    T query(int l, int r) {
        int k = lg[r - l + 1];
        return func(f[k][l], f[k][r - (1 << k) + 1]);
    }
};
typedef pair<int, int> pii;
int Q, gPos[1 << 18], dfn[1 << 18], per[1 << 18];
LL gDep[1 << 18];
vector<pair<int, int>> g[1 << 18];
STable<1 << 19, pii> Lca{ [](pii a, pii b) { return min(a, b); } };
void dfs(int u, int fa, int d) {
    static int cnt = 0;
    dfn[u] = ++cnt;
    gPos[u] = Lca.insert({ d, u });
    for (auto &&[v, w] : g[u])
        if (v != fa) {
            gDep[v] = gDep[u] + w;
            dfs(v, u, d + 1);
            Lca.insert({ d, u });
        }
}
int getLca(int u, int v) {
    if (gPos[u] > gPos[v])
        swap(u, v);
    return Lca.query(gPos[u], gPos[v]).second;
}
LL getDist(int u, int v) { return gDep[u] + gDep[v] - 2 * gDep[getLca(u, v)]; }
vector<pair<int, LL>> t[1 << 18];
void buildVTree(vector<int> h) {
    static int vis[1 << 18], tim = 0, stk[1 << 18];
    if (h.empty())
        return;
    ++tim;
    sort(h.begin(), h.end(), [&](int u, int v) { return dfn[u] < dfn[v]; });
    bool flag = 0;
    if (h[0] != 1)
        h.insert(h.begin(), 1);
    else
        flag = 1;
    h.erase(unique(h.begin(), h.end()), h.end());
    auto link = [&](int u, int v) {
        if (vis[u] < tim)
            vis[u] = tim, t[u].clear();
        if (vis[v] < tim)
            vis[v] = tim, t[v].clear();
        t[u].emplace_back(v, getDist(u, v));
        t[v].emplace_back(u, getDist(u, v));
    };
    int top = 0;
    stk[++top] = h[0];
    for (int i = 1; i < h.size(); i++) {
        int k = getLca(stk[top], h[i]);
        if (k != stk[top]) {
            while (top >= 2 && dfn[stk[top - 1]] > dfn[k]) link(stk[top - 1], stk[top]), --top;
            if (stk[top - 1] == k)
                link(stk[top], k), --top;
            else
                link(stk[top], k), stk[top] = k;
        }
        stk[++top] = h[i];
    }
    while (top >= 3) link(stk[top - 1], stk[top]), --top;
    if (top >= 2 && (flag || vis[1] == tim))
        link(stk[top - 1], stk[top]);
}
pair<vector<pair<LL, int>>, int> calcChain(vector<int> a) {
    if (a.empty())
        return {};
    static LL hei[1 << 18];
    static int down[1 << 18], son[1 << 18];
    buildVTree(a);
    function<pair<LL, int>(int, int, LL)> dfs1 = [&](int u, int fa, LL d) {
        pair<LL, int> ret = { d, u };
        for (auto &&[v, w] : t[u])
            if (v != fa) {
                ret = max(ret, dfs1(v, u, d + w));
            }
        return ret;
    };
    int root = dfs1(a[0], 0, 0).second;
    hei[0] = -1e18;
    vector<pair<LL, int>> vec = {};
    function<void(int, int)> dfs2 = [&](int u, int fa) {
        down[u] = u, hei[u] = son[u] = 0;
        for (auto &&[v, w] : t[u])
            if (v != fa) {
                dfs2(v, u);
                if (hei[v] + w > hei[u])
                    hei[u] = hei[son[u] = v] + w, down[u] = down[v];
            }
        for (auto &&[v, w] : t[u])
            if (v != fa) {
                if (v != son[u])
                    vec.emplace_back(hei[v] + w, down[v]);
            }
    };
    dfs2(root, 0);
    vec.emplace_back(hei[root], down[root]);
    sort(vec.begin(), vec.end(), greater<pair<LL, int>>{});
    vec.resize(min(int(vec.size()), 100));
    return { vec, root };
}
STable<2010, vector<int>> T{ [](vector<int> a, vector<int> b) {
    for (int x : b) a.push_back(x);
    auto [ret, root] = calcChain(a);
    vector<int> res = { root };
    for (int i = 0; i < min(int(ret.size()), 100); i++) res.push_back(ret[i].second);
    return res;
} };
int bel[1 << 18];
int main() {
#ifndef nfio
    freopen("nomenclature.in", "r", stdin);
    freopen("nomenclature.out", "w", stdout);
#endif
    scanf("%*d%d%d", &testop, &n);
    for (int i = 1, u, v, w; i < n; i++) {
        scanf("%d%d%d", &u, &v, &w);
        g[u].emplace_back(v, w);
        g[v].emplace_back(u, w);
    }
    for (int i = 1; i <= n; i++) scanf("%d", &per[i]);
    // for (int i = 1; i <= n; i++) bel[i] = (i - 1) / 100 + 1;
    // for (int i = 1; i <= n; i++) bel[i] = 1;
    dfs(1, 0, 1);
    for (int L = 1; L <= n; L += 100) {
        int R = min(L + 100 - 1, n);
        int b = T.insert(vector<int>(per + L, per + R + 1));
        for (int i = L; i <= R; i++) bel[i] = b;
    }
    scanf("%d", &Q);
    for (int L, R, k; Q--;) {
        scanf("%d%d%d", &L, &R, &k);
        decode(L, R, k, lstans, testop);
        // vector<pair<LL, int>> ret = calcChain(vector<int>(per + l, per + r + 1)).first;
        vector<int> small = {};
        int Lb = bel[L], Rb = bel[R];
        while (L <= R && bel[L] == Lb) small.push_back(per[L++]);
        while (L <= R && bel[R] == Rb) small.push_back(per[R--]);
        if (Lb + 1 <= Rb - 1) {
            vector<int> res = T.query(Lb + 1, Rb - 1);
            for (int x : res) small.push_back(x);
        }
        vector<pair<LL, int>> ret = calcChain(small).first;
        LL ans = 0;
        for (int i = 0; i < min(k - 1, int(ret.size())); i++) ans += ret[i].first;
        printf("%lld\n", lstans = ans);
    }
    return 0;
}

posted @ 2023-10-08 19:58  caijianhong  阅读(34)  评论(0)    收藏  举报