Luogu P14260 期待(counting) 题解 [ 蓝 ] [ 前缀和 ] [ 组合计数 ]
期待:想了大概 15min,写用了 1h 左右。这题按照部分分一步一步去想应该是不难出正解的,难点应该在于实现上。
看到题感觉不太好直接入手,于是先考虑特殊性质。特殊性质 A 的做法比较神秘,特殊性质 B 就是个骗分的,没啥启发性。
而特殊性质 C 是真正对正解有帮助的部分分。从链的角度考虑,可以把两个必经点 \(a, b\) 在链上标出来,然后很显然可以枚举 \(\bm{u, v}\) 之间的长度,根据 \(u, v\) 谁在左、谁在右分类讨论,算出 \(u, v\) 的取值区间,乘法原理计算即可。此处同样有一个简化计数流程的观察:大多数的移动方案都是成对出现的,也就是说大多数情况下我们只需要对正向走的方案计数一次,反向的不用算,直接将正向的答案乘 \(2\) 即可。
由特殊性质的做法,启发我们通过枚举 \(u, v\) 之间的距离 \(d\) 进行计数。考虑正解,这里着重对计数过程讲解:
因为是无根树,为了方便刻画,我们强制将 \(\bm a\) 钦定为树根。且在下文中,假设 \(u\) 的必经点为 \(b\),\(v\) 的必经点为 \(a\);\(u, v\) 最后的位置为 \(u', v'\);\(c\) 表示同时为 \(b\) 的祖先和 \(a\) 的儿子的节点,\(T\) 表示原树删掉 \(c\) 的子树后剩下的树。
对 \(u, v\) 的方位进行讨论,并钦定向上走为正方向:
- \(u\) 在下,\(v\) 在上:对 \(u, v'\) 计数。
- 需要满足 \(dep_u\ge d\),因为 \(v\) 一旦不是 \(u\) 的祖先了,则向上走会使得 \(v\) 一直无法与 \(a\) 重合。
- 需要满足 \(dep_u\ge dep_b\),因为是向上走,\(u\) 想要和 \(b\) 重合就必须在 \(b\) 子树内。
- 需要满足 \(dep_{v'}\ge \max\{d - dep_b, 0\}\)。其中 \(\max\{d - dep_b, 0\}\) 的含义是当 \(v\) 与 \(a\) 重合时,\(u\) 与 \(b\) 重合所需的最少步数。这个限制是因为只有 \(u, v\) 都满足要求了才是一个合法的方案。
- \(u\) 在上,\(v\) 在下:对 \(u', v\) 计数。
- 需要满足 \(dep_v - d \ge dep_b\)。其中 \(dep_v - d = dep_u\)。因为只有 \(u\) 在 \(b\) 的子树内,向上走的时候才能有重合。
- 需要满足 \(dep_{u'}\ge d\)。因为当 \(v\) 与 \(a\) 重合时,\(u\) 会往 \(T\) 内延伸 \(d\) 的长度。
发现我们只需要用到 \(b\) 子树内、\(T\) 内的深度信息。于是可以分别对这两棵树做 DFS,然后把所有节点的深度扔进一个桶里,统计的时候运用前缀和、乘法原理计数即可。
发现还是过不了,手模第一个和第二个小样例可以发现两个 corner case:
- 当 \(a, b\) 相差为 \(1\) 的时候,会多一个方案:\((a, b)\to (b, a)\)。
- 当 \(a, b\) 相差大于等于 \(2\) 的时候,\((a, b)\) 可能被记录了两次,需要减掉一次。
判掉这两个 corner case 即可通过,时间复杂度 \(O(n)\)。
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi = pair<int, int>;
const int N = 100005;
int n, a, b, mxdep, father[N], dep[N], sz[N], tot[N], smtot[N];
vector<int> g[N];
ll ans;
void dfs1(int u, int fa) // 求深度、父节点、子树大小
{
father[u] = fa;
sz[u] = 1;
for(auto v : g[u])
{
if(v == fa) continue;
dep[v] = dep[u] + 1;
dfs1(v, u);
sz[u] += sz[v];
}
}
void dfs2(int u, int fa) // 求 b 子树内的桶
{
tot[dep[u]]++;
for(auto v : g[u])
{
if(v == fa) continue;
dfs2(v, u);
}
}
void dfs3(int u, int ban) // 求 T 内的桶
{
smtot[dep[u]]++;
for(auto v : g[u])
{
if(v == father[u] || v == ban) continue;
dfs3(v, ban);
}
}
void solve()
{
cin >> n >> a >> b;
ans = 0;
for(int i = 1; i <= n; i++)
g[i].clear();
for(int i = 1; i < n; i++)
{
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
dep[a] = 0;
memset(tot, 0, sizeof(tot));
memset(smtot, 0, sizeof(smtot));
dfs1(a, 0);
dfs2(b, father[b]);
int now = b;
while(father[now] != a)
now = father[now];
dfs3(a, now);
for(int i = 1; i <= n; i++)
{
tot[i] += tot[i - 1];
smtot[i] += smtot[i - 1];
}
for(int d = 0; d < n; d++)
{
// u 下 v 上
int xj = max(dep[b], d);
int rj = max(0, d - dep[b]);
ans += 2ll * (tot[n] - tot[xj - 1]) * (smtot[n] - (rj == 0 ? 0 : smtot[rj - 1]));
// u 上 v 下
if(d == 0) continue;
xj = dep[b] + d;
ans += 2ll * (tot[n] - tot[xj - 1]) * (smtot[n] - (d == 0 ? 0 : smtot[d - 1]));
}
if(dep[b] == 1) ans++;
if(dep[b] >= 2) ans--;
cout << ans << "\n";
}
int main()
{
// freopen("counting5.in", "r", stdin);
// freopen("sample.out", "w", stdout);
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int t;
cin >> t;
while(t--) solve();
return 0;
}

浙公网安备 33010602011771号