题解:P14666 [KenOI 2025] 游走题
很好的数数题。
思路
观察样例,猜一个结论:游走的终点只可能是节点 \(1\)。考虑证明,一个节点如果往儿子走最终显然是可以再走回父亲的,但如果走到了父亲就不能再走回去了。所以只有走到一个没有父亲的点且把这个点的所有子树都走遍才会停止。这里的走遍指至少在这个子树中走过了一个节点。
得出这个结论之后,问题便转化为 \(1\) 到 \(s\) 的路径必走,同时可以走到其它节点,问走过本质不同的路径长度之和。因为 \(s\) 到 \(1\) 的路径只会走一遍,而其它的存在与路径上的点都需要走两遍,设路径总数为 \(num\)、其它存在于路径上的点的路径长度之和为 \(len\),则答案便等于:\(dis(1, s)\times num+2\times len\)。
考虑通过树形 dp 计算出 \(num\) 和 \(len\)。发现直接对整棵树进行 dp 并不容易,但如果 \(s=1\) 是很好计算的,于是想到拆贡献,最后再合并。可以先把 \(1\) 到 \(s\) 的路径上的边删掉,然后此时原树变成了一个森林,考虑对于每一个森林中的树进行树形 dp。对于每一个树都钦定在 \(1\) 到 \(s\) 的路径上的点 \(u\) 为根,然后每一棵树就等同于跑一遍 \(s=u\) 的树形 dp,这个东西可以直接套用 \(s=1\) 的情况。
设 \(f_u\) 表示以 \(u\) 为根的子树中的合法路径长度之和,\(g_u\) 表示以 \(u\) 为根的子树中的合法路径数量。
- 如果 \(u\neq 1\),则每一个子树可以选择走和不走,所以:
- 如果 \(u=1\),则每一个子树都必须走,所以:
接下来考虑 \(f\) 的转移。在 \(u\) 的子树中,\(f_v(v\in son_u)\) 可能和其它的 \(u\) 的儿子一起产生贡献,所以只需要知道其它儿子的组合的方案数就可以算出 \(f_u\)。同时,因为从 \(v\) 到 \(u\) 也会产生 \(1\) 的贡献,这个贡献的前提是 \(v\) 必选,其它的儿子可以任意组合。
- 如果 \(u\neq 1\),则每一个子树可以选择走和不走,转移为:
- 如果 \(u=1\),则每一个子树都必须走,转移为:
最后再把 \(1\) 到 \(s\) 的路径上的点合并即可。因为路径上的每一个点都必须选,所以是按照 \(u=1\) 的转移方式合并的。
时间复杂度:\(O(n\log V)\)。要特判一下 \(s=1\) 的情况。
做法死了,因为可能没有逆元,所以考虑在转移的时候对每一个节点都做一个前后缀积,这样可以避免逆元的使用。此时的时间复杂度严格线性。
代码
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ld = long double;
using ull = unsigned long long;
using i128 = __int128;
using PII = pair<int, int>;
using PLL = pair<ll, ll>;
constexpr ll inf = (1ll << 62);
constexpr int N = 1e6 + 10;
template <typename T>
concept Can_bit = requires(T x) { x >>= 1; };
template <int MOD>
struct modint {
int val;
static int norm(const int& x) { return x < 0 ? x + MOD : x; }
static constexpr int get_mod() { return MOD; }
modint inv() const {
assert(val);
int a = val, b = MOD, u = 1, v = 0, t;
while (b > 0) t = a / b, swap(a -= t * b, b), swap(u -= t * v, v);
assert(b == 1);
return modint(u);
}
modint() : val(0) {}
modint(const int& m) : val(norm(m)) {}
modint(const long long& m) : val(norm(m % MOD)) {}
modint operator-() const { return modint(norm(-val)); }
bool operator==(const modint& o) { return val == o.val; }
bool operator<(const modint& o) { return val < o.val; }
modint& operator+=(const modint& o) { return val = (1ll * val + o.val) % MOD, *this; }
modint& operator-=(const modint& o) { return val = norm(1ll * val - o.val), *this; }
modint& operator*=(const modint& o) { return val = static_cast<int>(1ll * val * o.val % MOD), *this; }
modint& operator/=(const modint& o) { return *this *= o.inv(); }
modint& operator^=(const modint& o) { return val ^= o.val, *this; }
modint& operator>>=(const modint& o) { return val >>= o.val, *this; }
modint& operator<<=(const modint& o) { return val <<= o.val, *this; }
modint operator-(const modint& o) const { return modint(*this) -= o; }
modint operator+(const modint& o) const { return modint(*this) += o; }
modint operator*(const modint& o) const { return modint(*this) *= o; }
modint operator/(const modint& o) const { return modint(*this) /= o; }
modint operator^(const modint& o) const { return modint(*this) ^= o; }
modint operator>>(const modint& o) const { return modint(*this) >>= o; }
modint operator<<(const modint& o) const { return modint(*this) <<= o; }
friend std::istream& operator>>(std::istream& is, modint& a) {
long long v;
return is >> v, a.val = norm(v % MOD), is;
}
friend std::ostream& operator<<(std::ostream& os, const modint& a) { return os << a.val; }
friend std::string tostring(const modint& a) { return std::to_string(a.val); }
template <Can_bit T>
friend modint qpow(const modint& a, const T& b) {
assert(b >= 0);
modint x = a, res = 1;
for (T p = b; p; x *= x, p >>= 1)
if (p & 1) res *= x;
return res;
}
};
using M107 = modint<1000000007>;
using M998 = modint<998244353>;
constexpr int mod = 1e9 + 7;
using Mint = M107;
// constexpr mod = ...;
// using Mint = modint<mod>;
struct Fact {
std::vector<Mint> fact, factinv;
const int n;
Fact(const int& _n) : n(_n), fact(_n + 1, Mint(1)), factinv(_n + 1) {
for (int i = 1; i <= n; ++i) fact[i] = fact[i - 1] * i;
factinv[n] = fact[n].inv();
for (int i = n; i; --i) factinv[i - 1] = factinv[i] * i;
}
Mint C(const int& n, const int& k) {
if (n < 0 || k < 0 || n < k) return 0;
return fact[n] * factinv[k] * factinv[n - k];
}
Mint A(const int& n, const int& k) {
if (n < 0 || k < 0 || n < k) return 0;
return fact[n] * factinv[n - k];
}
};
int n, s;
vector<vector<int>> G(N);
vector<Mint> f(N), g(N);
vector<int> depth(N), path;
bitset<N> flag;
bool dfs1(int u, int fa) {
bool ok = false;
if (u == s) {
ok = true;
}
for (auto v : G[u]) {
if (v == fa) continue;
depth[v] = depth[fa] + 1;
ok |= dfs1(v, u);
}
if (ok) {
path.push_back(u);
}
return ok;
}
void dfs2(int u, int fa) {
g[u] = 1;
vector<Mint> pre(int(G[u].size())), suf(int(G[u].size()));
for (auto v : G[u]) {
if (v == fa || flag[v]) continue;
dfs2(v, u);
g[u] *= (g[v] + (!u ? 0 : 1));
}
for (int i = 0; i < G[u].size(); i++) {
pre[i] = (!i ? 1 : pre[i - 1]);
if (G[u][i] == fa || flag[G[u][i]]) continue;
pre[i] *= (g[G[u][i]] + (!u ? 0 : 1));
}
for (int i = G[u].size() - 1; i >= 0; i--) {
suf[i] = (i == G[u].size() - 1 ? 1 : suf[i + 1]);
if (G[u][i] == fa || flag[G[u][i]]) continue;
suf[i] *= (g[G[u][i]] + (!u ? 0 : 1));
}
for (int i = 0; i < G[u].size(); i++) {
int v = G[u][i];
if (v == fa || flag[v]) continue;
f[u] += (f[v] + g[v]) * (!i ? 1 : pre[i - 1]) * (i == G[u].size() - 1 ? 1 : suf[i + 1]);
}
}
void solve() {
cin >> n >> s;
s--;
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
u--;
v--;
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(0, -1);
if (!s) {
dfs2(0, -1);
cout << f[0] * 2 << "\n";
return;
}
reverse(path.begin(), path.end());
for (int i = 0; i < path.size(); i++) {
flag[path[i]] = true;
}
dfs2(0, -1);
dfs2(s, -1);
for (int i = 1; i < path.size() - 1; i++) {
dfs2(path[i], -1);
}
Mint ans = 0, calc = 1;
vector<Mint> pre(int(path.size())), suf(int(path.size()));
for (int i = 0; i < path.size(); i++) {
calc *= g[path[i]];
pre[i] = (!i ? 1 : pre[i - 1]) * g[path[i]];
}
for (int i = path.size() - 1; i >= 0; i--) {
suf[i] = (i == path.size() - 1 ? 1 : suf[i + 1]) * g[path[i]];
}
for (int i = 0; i < path.size(); i++) {
ans += f[path[i]] * (!i ? 1 : pre[i - 1]) * (i == path.size() - 1 ? 1 : suf[i + 1]);
}
Mint num = int(path.size()) - 1;
cout << ans * 2 + num * calc << "\n";
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int t = 1;
// cin >> t;
while (t--) {
solve();
}
return 0;
}

浙公网安备 33010602011771号