ABC221F Diameter set 题解
题意简述:
给定一棵 \(n\) 个节点的树,设它的直径是 \(D\),问有多少个集合满足集合中每两个点的距离都为 \(D\)。
\(\texttt{Data Range:} 1\le n\le 2\times 10^5\)。
考虑直径的性质:
- 树的每一条直径一定都经过一个公共点 / 一条公共边。经过的是点还是边取决于直径的长度是奇数还是偶数。
那么按照直径长度的奇偶性分类讨论,直接计算即可。
具体的:
- 若直径长度为偶数,设中点为 \(mid\),答案为 \(\prod\limits_{v\in son_{mid}}(cnt_v+1)-scnt-1\),其中 \(cnt_v\) 为 \(v\) 子树中距离 \(v\) 长度为 \(\frac{D}{2}-1\) 的个数,\(scnt\) 为距离 \(mid\) 长度为 \(\frac{D}{2}\) 的点的个数。
- 若直径长度为奇数,设中间的边为 \((mid,fmid)\),那么答案就是 \(cnt_{mid}\times cnt_{fmid}\),\(cnt_{mid}\) 为\(mid\) 子树中距离 \(mid\) 长度为 \(\frac{D}{2}\) 的点的个数。
代码:
#include <bits/stdc++.h>
#define DC int T = gi <int> (); while (T--)
#define DEBUG fprintf(stderr, "Passing [%s] line %d\n", __FUNCTION__, __LINE__)
#define File(x) freopen(x".in","r",stdin); freopen(x".out","w",stdout)
#define fi first
#define se second
#define pb push_back
#define mp make_pair
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair <int, int> PII;
typedef pair <LL, LL> PLL;
template <typename T>
inline T gi()
{
T x = 0, f = 1; char c = getchar();
while (c < '0' || c > '9') {if (c == '-') f = -1; c = getchar();}
while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return f * x;
}
const int N = 200003, M = N << 1, mod = 998244353;
int n;
int tot, head[N], ver[M], nxt[M];
int fa[N];
int mx, lft, rght;
int cnt;
inline void add(int u, int v) {ver[++tot] = v, nxt[tot] = head[u], head[u] = tot;}
inline int qpow(int x, int y)
{
int res = 1;
while (y)
{
if (y & 1) res = 1ll * res * x % mod;
x = 1ll * x * x % mod, y >>= 1;
}
return res;
}
void dfs(int u, int f, int dis)
{
if (dis > mx) mx = dis, rght = u;
fa[u] = f;
for (int i = head[u]; i; i = nxt[i])
{
int v = ver[i];
if (v == f) continue;
dfs(v, u, dis + 1);
}
}
void dfsson(int u, int f, int tar, int dis)
{
if (tar == dis) ++cnt;
for (int i = head[u]; i; i = nxt[i])
{
int v = ver[i];
if (v == f) continue;
dfsson(v, u, tar, dis + 1);
}
}
int main()
{
//freopen(".in", "r", stdin); freopen(".out", "w", stdout);
n = gi <int> ();
for (int i = 1; i < n; i+=1)
{
int u = gi <int> (), v = gi <int> ();
add(u, v), add(v, u);
}
dfs(1, 0, 0);
lft = rght, mx = 0;
dfs(lft, 0, 0);
int d = mx;
if (d % 2 == 0)
{
int mid = rght;
for (int i = 1; i <= d / 2; i+=1) mid = fa[mid];
int ans = 1, scnt = 0;
for (int i = head[mid]; i; i = nxt[i])
{
int v = ver[i];
cnt = 0;
dfsson(v, mid, d / 2 - 1, 0);
ans = 1ll * ans * (cnt + 1) % mod;
scnt += cnt;
}
printf("%d\n", (ans - 1 - scnt + mod) % mod);
}
else
{
int mid = rght;
for (int i = 1; i <= d / 2; i+=1) mid = fa[mid];
int fmid = fa[mid];
dfsson(mid, fmid, d / 2, 0);
int tcnt = cnt; cnt = 0;
dfsson(fmid, mid, d / 2, 0);
printf("%lld\n", 1ll * cnt * tcnt % mod);
}
return !!0;
}