28 S2模拟赛T2 开会council 题解

council

题面

给定一棵 \(n\) 个节点的树,每个节点有黑白两种颜色,还有 \(k\) 个特殊节点。

距离表示两个点间路径上边权的最大值。

我们每次指定一个白点,对于每个黑点,设 \(disb\) 表示其到任意一个特殊点距离的最小值,设 \(disa\) 表示此特殊点到指定白点的距离。

这个白点对答案的贡献即为 \(\sum \max (disa, disb)\)

但是有一个点的颜色不太稳定,可能由黑变白,也可能由白变黑。

你的任务是求出没有点变色的答案,以及对于每个点,其变色并且其余点不变色的答案。

\(1 \le n \le 2\times 10^5\)

题解

这道题条件很多,我们需要谨慎处理。

我们先不考虑 \(disb\) 的影响,假设每个点本身就是一个特殊点,也就是只有 \(disa\),是比较好做的。

因为每个点可能变色,所以我们要考虑一个黑点变成白点的情况。

对于每个点我们记 \(f(i,0/1)\) 表示所有白/黑点到 \(i\) 的距离和,朴素的做法是 \(O(n^2)\) 暴力跑出来。

考虑如何优化这个东西

因为我们每次都是求路径上的边权最大值,所以我们可以将树重构。

也就是从小到大枚举每条边 \((x,y,z)\),然后将 \(x,y\) 所属的连通块尝试连接起来,如果两个连通块不连通,那么 \(z\) 就会成为两个连通块的点之间的最大权值。

对每个连通块,我们记 \(siz_{fx, 0}, siz_{fx, 1}\) 分别表示白点和黑点数量,以及 \(tag_{fx,0}, tag_{fx, 1}\) 表示对连通块中的点的白黑贡献。

我们每次合并两个连通块的时候按秩合并,也就是将小的合并到大的里边。对于大块中的点,我们直接打标记,对于小块中的点,我们将其标记下放并插入到大块中。

对于每个点来说至多合并 \(O(\log n)\) 次,所以合并的时间复杂度 \(O(n \log n)\)

然后我们来处理到特殊点距离的最小值。

在这之前,我们要说明一个结论,假设起点为 \(x\),特殊点为 \(z\),白点为 \(y\)。路线 \(x \to z \to y\)\(x \to z \to x \to y\) 这两种情况的贡献是相同的。

假设我们不考虑 \(x \to z\) 相同的一段,那么也就是比较 \(z \to y\)\(z \to x \to y\) 的边权最大值,因为后者多饶了个弯,所以后者的贡献大于等于前者。

同理,我们也可以得出前者的贡献大于等于后者,所以前者的贡献就等于后者。

image-20251013210513798

所以我们分别考虑 \(x \to z\)\(x \to y\) 的贡献即可。

首先,我们可以先跑个多起点最短路算一下每个点到特殊点的距离最小值 \(dis_i\),时间复杂度 \(O(n \log n)\)

然后我们进行一个很巧妙的操作,从而将这个东西合并到我们刚才的操作中。我们不是要取这两者的最大值吗,发现这个好像和刚才的路径最大边权有点相似,考虑能否将这个到特殊点的距离也当做一个边权?

实际上是可以的,对每个点建一个新点,原点和新点之间的边权即为 \(dis_i\)

如果原点为黑点,我们设原点为无色点,新点为黑点。

否则原点为白点,新点为无色点。

image-20251013211902735

然后我们将这个新边加上即可进行统计。

总时间复杂度为 \(O(n \log n)\)

具体加加减减可以手模,具体实现看代码。

code

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <vector>
#include <queue>

using namespace std;

namespace michaele {

    typedef long long ll;

    const int N = 4e5 + 100, M = N << 1;

    int n, k;
    int h[N], ver[M], ne[M], e[M], tot;
    int kd[N], sta[N];
    int fa[N], siz[N][2];
    ll f[N][2], tag[N][2], dis[N];
    bool vis[N];
    vector <int> st[N];

    struct edge {
        int x, y;
        ll z;
        bool operator < (const edge &t) const {
            return z < t.z;
        }
    };
    vector <edge> E;

    void add (int x, int y, int z) {
        ver[ ++ tot] = y;
        ne[tot] = h[x];
        h[x] = tot;
        e[tot] = z;
    }

    int fin (int x) {
        return x == fa[x] ? x : fa[x] = fin (fa[x]);
    }

    void dijk () {
        priority_queue <pair <ll, int> > q;
        for (int i = 1; i <= k; i ++) {
            q.push ({0, sta[i]});
            dis[sta[i]] = 0;
        }
        while (q.size ()) {
            int x = q.top ().second;
            q.pop ();
            if (vis[x]) continue;
            vis[x] = 1;
            for (int i = h[x]; i; i = ne[i]) {
                int y = ver[i];
                if (dis[y] > max (dis[x], (ll)e[i])) {
                    dis[y] = max (dis[x], (ll)e[i]);
                    q.push ({-dis[y], y});
                }
            }
        }
    }

    void clear () {
        tot = 0;
        int size = (n + 5) << 1;
        memset (h, 0, size * 4);
        memset (vis, 0, size);
        memset (dis, 0x3f, size * 8);
        memset (siz, 0, size * 4 * 2);
        memset (tag, 0, size * 8 * 2);
        memset (f, 0, size * 8 * 2);
        E.clear ();
        for (int i = 1; i <= n * 2; i ++) {
            vector <int> emp;
            st[i].swap (emp);
        }
    }

    void solve () {
        cin >> n >> k;
        clear ();

        for (int i = 1; i <= n; i ++) {
            cin >> kd[i];
        }
        for (int i = 1; i < n; i ++) {
            int x, y, z;
            cin >> x >> y >> z;
            add (x, y, z);
            add (y, x, z);
            E.push_back ({x, y, z});
        }
        for (int i = 1; i <= k; i ++) {
            cin >> sta[i];
        }

        dijk ();

        for (int i = 1; i <= n; i ++) {
            fa[i] = i, fa[i + n] = i + n;
            st[i].push_back (i);
            st[i + n].push_back (i + n);
            if (kd[i]) siz[i + n][1] = 1;
            else siz[i][0] = 1;
            E.push_back ({i, i + n, dis[i]});
        }
        sort (E.begin(), E.end ());

        for (auto p : E) {
            int fx = fin (p.x), fy = fin (p.y);
            if (fx == fy) continue;
            if (siz[fx][0] + siz[fx][1] < siz[fy][0] + siz[fy][1]) {
                swap (fx, fy);
            }
            fa[fy] = fx;
            tag[fx][0] += (ll)p.z * siz[fy][0];
            tag[fx][1] += (ll)p.z * siz[fy][1];
            for (auto t : st[fy]) {
                f[t][0] += tag[fy][0] + (ll)p.z * siz[fx][0] - tag[fx][0];
                f[t][1] += tag[fy][1] + (ll)p.z * siz[fx][1] - tag[fx][1];
                st[fx].push_back (t);
            }
            siz[fx][0] += siz[fy][0];
            siz[fx][1] += siz[fy][1];
        }

        int ff = fin (1);

        for (int i = 1; i <= n * 2; i ++) {
            f[i][0] += tag[ff][0];
            f[i][1] += tag[ff][1];
        }

        ll ans = 0;
        for (int i = 1; i <= n; i ++) {
            if (kd[i]) ans += f[i + n][0];
        }
        cout << ans << endl;
        for (int i = 1; i <= n; i ++) {
            if (kd[i] == 0) {
                cout << ans - f[i][1] + f[i + n][0] - dis[i] << endl;
            } else {
                cout << ans - f[i + n][0] + f[i][1] - dis[i] << endl;
            }
        }
    }

    void Main () {
        int T;
        cin >> T;
        while (T --) {
            solve ();
        }
    }
}

int main () {

    michaele :: Main ();

    return 0;
}
posted @ 2025-10-16 15:04  michaele  阅读(7)  评论(0)    收藏  举报