题解:CF1051F

题解:CF1051F

提供一种奇怪的 dp 做法。

首先观察数据范围可以发现 \(m-n \le 20\),所以这是一张十分稀疏的图,肯定是要建出一棵树然后对返祖边进行操作,但是这个返祖边该怎么计算贡献呢?

假设现在建出了一颗生成树(随便的一颗),返祖边共有 \(tot\) 条,设 \(st_i,en_i,w_i\) 分别为第 \(i\) 条的起点、终点、边权。

那么可以发现从 \(x\)\(y\) 的一条路径一定可以表示为:

\(x\rightarrow st_i\rightarrow en_i\rightarrow st_j\rightarrow en_j\rightarrow\cdots\rightarrow st_k\rightarrow en_k\rightarrow y\)

也就是经过一些返祖边后到达 \(y\)

首先会有一个暴力做法,直接枚举经过的返祖边顺序,不过这是 \(O(q\times tot!)\) 的。

这时候就可以请出 dp 了。

因为其实中间经过了什么并不重要,只需要知道头和尾是什么就好(也就是上面的 \(st_i\)\(en_k\))。

而头和尾的总数量只有是 \(tot^2\) 个。

所以可以直接设 \(dp_{i,j}\) 表示从第 \(i\) 条边到第 \(j\) 条边的最短距离。

转移考虑枚举头尾,也就是:

\[dp_{i,j} \gets \min_{l=1}^{tot}{dp_{l,j}+dis(en_i,st_l)}+w_i \]

\[dp_{i,j} \gets \min_{l=1}^{tot}{dp_{i,l}+dis(en_l,st_j)}+w_j \]

\[dp_{i,i}=w_i \]

其中:\(dis(x,y)\) 表示 \(x\)\(y\) 在树上的距离。

但是发现这是有后效性的,所以可以加一维表示共有多少条边,变为:

\[dp_{i,j,k} \gets \min_{l=1}^{tot}{dp_{l,j,{k - 1}}+dis(en_i,st_l)}+w_i \]

\[dp_{i,j,k} \gets \min_{l=1}^{tot}{dp_{i,l,{k - 1}}+dis(en_l,st_j)}+w_j \]

\[dp_{i,i,k}=w_i \]

状态总共有 \(tot^3\) 种,转移需要 \(O(tot)\),预处理 \(lca\) 复杂度 \(O(tot^2\log n)\),所以总复杂度为 \(O(tot^4+tot^2\log n)\)

最后统计答案,直接枚举头尾,然后用最开始的路径表示方法计算即可。

这部分复杂度 \(O(q\times tot^2+q\times tot\times \log n)\)

使用 \(O(1) lca\) 可以更快。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
#include <set>
#include <map>

using namespace std;

const int N = 1e5 + 10, M = 50;

#define int long long
#define fi first
#define se second
#define lid id << 1
#define rid id << 1 | 1
#define emp emplace_back

using pii = pair <int, int>;
const int inf = 0x3f3f3f3f3f3f3f3f;

int dp[M][M][M], n, m, tot, st[M], en[M], w[M], son[N], sum[N], top[N], siz[N], dep[N], f[N], fa[N], pre[M][M];
vector <pii> G[N];

void dfs(int x, int fa)
{
    f[x] = fa, siz[x] = 1, dep[x] = dep[fa] + 1;
    for (auto [to, w] : G[x])
    {
        if (to == fa) continue;
        sum[to] = sum[x] + w;
        dfs(to, x);
        if (siz[to] > siz[son[x]]) son[x] = to;
        siz[x] += siz[to];
    }
}

void dfs2(int x, int fa)
{
    top[x] = fa;
    if (son[x]) dfs2(son[x], fa);
    for (auto [to, w] : G[x])
    {
        if (to == f[x] || to == son[x]) continue;
        dfs2(to, to);
    }
}

int lca(int x, int y)
{
    while (top[x] != top[y])
    {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        x = f[top[x]];
    }
    return dep[x] > dep[y] ? y : x;
}

int dis(int x, int y)
{
    return sum[x] + sum[y] - 2 * sum[lca(x, y)];
}

int calc(int i, int j, int k)
{
    if (k <= 0) return inf;
    if (i == j) return (dp[i][j][k] = w[i]);
    if (dp[i][j][k]) return dp[i][j][k];
    int res = inf;
    for (int l = 1; l <= tot; l++)
    {
        res = min(res, calc(i, l, k - 1) + pre[l][j] + w[j]);
        res = min(res, calc(l, j, k - 1) + pre[i][l] + w[i]);
    }
    res = min(res, calc(i, j, k - 1));
    return (dp[i][j][k] = res);
}

void add(int x, int y, int w) {G[x].emp(y, w);}

int Find(int x)
{
    if (fa[x] == x) return fa[x];
    return fa[x] = Find(fa[x]);
}

int preDis[N][2][2], p[M];

struct S
{
    int x, y, w;
}e[N];

signed main()
{
    // freopen("data.in", "r", stdin); freopen("data.out", "w", stdout);
    ios :: sync_with_stdio(false), cin.tie(0), cout.tie(0);

    cin >> n >> m;
    for (int i = 1; i <= m; i++)
    {
        int x, y, w; cin >> x >> y >> w;
        e[i] = {x, y, w};
    }
    sort(e + 1, e + 1 + m, [](S x, S y)
    {
        return x.w < y.w;
    });
    for (int i = 1; i <= n; i++) fa[i] = i;
    for (int i = 1; i <= m; i++)
    {
        int x = e[i].x, y = e[i].y;
        int xx = Find(x), yy = Find(y);
        if (xx == yy)
        {
            ++tot;
            st[tot] = x, en[tot] = y, w[tot] = e[i].w;
            p[tot] = tot + 1;
            ++tot;
            en[tot] = x, st[tot] = y, w[tot] = e[i].w;
            p[tot] = tot - 1;
            continue;
        }
        add(x, y, e[i].w), add(y, x, e[i].w);
        fa[xx] = yy;
    }
    dfs(1, 0), dfs2(1, 1);
    for (int i = 1; i <= tot; i++) for (int j = 1; j <= tot; j++) pre[i][j] = dis(en[i], st[j]);
    for (int i = 1; i <= tot; i++) for (int j = 1; j <= tot; j++) for (int k = 1; k <= tot; k++) calc(i, j, k);
    int q; cin >> q;
    for (int i = 1; i <= q; i++)
    {
        int x, y, res = inf; cin >> x >> y;
        res = dis(x, y);
        for (int j = 1; j <= tot; j++)
        {
            preDis[j][0][0] = dis(x, st[j]);
            preDis[j][1][1] = dis(y, en[j]);
        }
        for (int j = 1; j <= tot; j++) for (int k = 1; k <= tot; k++)
        {
            res = min(res, preDis[j][0][0] + dp[j][k][tot] + preDis[k][1][1]);
        }
        cout << res << '\n';
    }

    return 0;
}
posted @ 2025-08-24 11:49  QEDQEDQED  阅读(11)  评论(0)    收藏  举报