Luogu P14260 期待(counting) 题解 [ 蓝 ] [ 前缀和 ] [ 组合计数 ]

期待:想了大概 15min,写用了 1h 左右。这题按照部分分一步一步去想应该是不难出正解的,难点应该在于实现上。

看到题感觉不太好直接入手,于是先考虑特殊性质。特殊性质 A 的做法比较神秘,特殊性质 B 就是个骗分的,没啥启发性。

而特殊性质 C 是真正对正解有帮助的部分分。从链的角度考虑,可以把两个必经点 \(a, b\) 在链上标出来,然后很显然可以枚举 \(\bm{u, v}\) 之间的长度,根据 \(u, v\) 谁在左、谁在右分类讨论,算出 \(u, v\) 的取值区间,乘法原理计算即可。此处同样有一个简化计数流程的观察:大多数的移动方案都是成对出现的,也就是说大多数情况下我们只需要对正向走的方案计数一次,反向的不用算,直接将正向的答案乘 \(2\) 即可。

由特殊性质的做法,启发我们通过枚举 \(u, v\) 之间的距离 \(d\) 进行计数。考虑正解,这里着重对计数过程讲解:

因为是无根树,为了方便刻画,我们强制将 \(\bm a\) 钦定为树根。且在下文中,假设 \(u\) 的必经点为 \(b\)\(v\) 的必经点为 \(a\)\(u, v\) 最后的位置为 \(u', v'\)\(c\) 表示同时为 \(b\) 的祖先和 \(a\) 的儿子的节点,\(T\) 表示原树删掉 \(c\) 的子树后剩下的树。

\(u, v\) 的方位进行讨论,并钦定向上走为正方向

  • \(u\) 在下,\(v\) 在上:对 \(u, v'\) 计数。
    • 需要满足 \(dep_u\ge d\),因为 \(v\) 一旦不是 \(u\) 的祖先了,则向上走会使得 \(v\) 一直无法与 \(a\) 重合。
    • 需要满足 \(dep_u\ge dep_b\),因为是向上走,\(u\) 想要和 \(b\) 重合就必须在 \(b\) 子树内。
    • 需要满足 \(dep_{v'}\ge \max\{d - dep_b, 0\}\)。其中 \(\max\{d - dep_b, 0\}\) 的含义是当 \(v\)\(a\) 重合时,\(u\)\(b\) 重合所需的最少步数。这个限制是因为只有 \(u, v\) 都满足要求了才是一个合法的方案。
  • \(u\) 在上,\(v\) 在下:对 \(u', v\) 计数。
    • 需要满足 \(dep_v - d \ge dep_b\)。其中 \(dep_v - d = dep_u\)。因为只有 \(u\)\(b\) 的子树内,向上走的时候才能有重合。
    • 需要满足 \(dep_{u'}\ge d\)。因为当 \(v\)\(a\) 重合时,\(u\) 会往 \(T\) 内延伸 \(d\) 的长度。

发现我们只需要用到 \(b\) 子树内、\(T\) 内的深度信息。于是可以分别对这两棵树做 DFS,然后把所有节点的深度扔进一个里,统计的时候运用前缀和、乘法原理计数即可。

发现还是过不了,手模第一个和第二个小样例可以发现两个 corner case:

  • \(a, b\) 相差为 \(1\) 的时候,会多一个方案:\((a, b)\to (b, a)\)
  • \(a, b\) 相差大于等于 \(2\) 的时候,\((a, b)\) 可能被记录了两次,需要减掉一次。

判掉这两个 corner case 即可通过,时间复杂度 \(O(n)\)

#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi = pair<int, int>;
const int N = 100005;
int n, a, b, mxdep, father[N], dep[N], sz[N], tot[N], smtot[N];
vector<int> g[N];
ll ans;
void dfs1(int u, int fa) // 求深度、父节点、子树大小
{
    father[u] = fa;
    sz[u] = 1;
    for(auto v : g[u])
    {
        if(v == fa) continue;
        dep[v] = dep[u] + 1;
        dfs1(v, u);
        sz[u] += sz[v];
    }
}
void dfs2(int u, int fa) // 求 b 子树内的桶
{
    tot[dep[u]]++;
    for(auto v : g[u])
    {
        if(v == fa) continue;
        dfs2(v, u);
    }
}
void dfs3(int u, int ban) // 求 T 内的桶
{
    smtot[dep[u]]++;
    for(auto v : g[u])
    {
        if(v == father[u] || v == ban) continue;
        dfs3(v, ban);
    }
}
void solve()
{
    cin >> n >> a >> b;
    ans = 0;
    for(int i = 1; i <= n; i++)
        g[i].clear();
    for(int i = 1; i < n; i++)
    {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dep[a] = 0;
    memset(tot, 0, sizeof(tot));
    memset(smtot, 0, sizeof(smtot));
    dfs1(a, 0);
    dfs2(b, father[b]);
    int now = b;
    while(father[now] != a)
        now = father[now];
    dfs3(a, now);
    for(int i = 1; i <= n; i++)
    {
        tot[i] += tot[i - 1];
        smtot[i] += smtot[i - 1];
    }
    for(int d = 0; d < n; d++)
    {
        // u 下 v 上
        int xj = max(dep[b], d);
        int rj = max(0, d - dep[b]);
        ans += 2ll * (tot[n] - tot[xj - 1]) * (smtot[n] - (rj == 0 ? 0 : smtot[rj - 1]));
        // u 上 v 下
        if(d == 0) continue;
        xj = dep[b] + d;
        ans += 2ll * (tot[n] - tot[xj - 1]) * (smtot[n] - (d == 0 ? 0 : smtot[d - 1]));
    }
    if(dep[b] == 1) ans++;
    if(dep[b] >= 2) ans--;
    cout << ans << "\n";
}
int main()
{
    // freopen("counting5.in", "r", stdin);
    // freopen("sample.out", "w", stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int t;
    cin >> t;
    while(t--) solve();
    return 0;
}
posted @ 2025-10-20 01:44  KS_Fszha  阅读(17)  评论(0)    收藏  举报