ZR 25 summer D7T1 题解 | 树上问题,dp
标签:树上问题,dp
题意
给你一个树和若干树上路径,到达一个路径中的点就一定要走完整条路径。问走过的点的集合的种类。
思路
发现可以缩点,将若干相互连接的路径缩成一个点。
赛时:缩点不知道怎么缩,先往后想。
发现缩完之后很简单,简单树形 dp,记录每个点 \(u\) 的 \(f,g\),分别是一个点子树内的答案和连接到这个点方案数。
\[f_u=\sum\limits _{v \in son_u} f_v + \prod\limits _{v \in son_u} (g_v + 1)
\]
\[g_u=\prod\limits _{v \in son_u} (g_v + 1)
\]
缩点时,我们可以记录一个点是否被上面的点用路径连接。
具体地,\(tag_u\) 表示有无路径连接 \((u,fa_u)\)。
这个可以每次读入路径 \((u,v)\) 时:
\[tag_u \leftarrow tag_u + 1
\]
\[tag_v \leftarrow tag_v + 1
\]
\[tag_{lca(u,v)} \leftarrow tag_{lca(u,v)} -2
\]
记得 dfs 时加一下。
然后没了,这个鬼做法赛时被卡常 70 pts,但是良心出题人赛后改时限,喜提 AC。
代码
注:fin
、fout
是题目提供的快读。
const int N = 2e6 + 5;
const int MXLOG = 23;
const int mod = 998244353;
int n, m;
int tag[N];
ll f[N], g[N];
vector<int> e[N], e2[N];
pair<int, int> gr[N]; int E;
int fa[N][MXLOG], dep[N];
int LOG;
// int st[N << 1][MXLOG];
// int eu[N << 1], dfn[N], tot;
// int lg[N << 1];
void dfs1(int u, int FA){
dep[u] = dep[FA] + 1;
fa[u][0] = FA;
rep(j, 1, LOG){
fa[u][j] = fa[fa[u][j - 1]][j - 1];
}
for(int v : e[u]){
if(v == FA) continue;
dfs1(v, u);
// eu[++tot] = u;
}
}
int LCA(int x, int y){
if(dep[x] < dep[y]) swap(x, y);
per(j, LOG, 0){
if(dep[fa[x][j]] >= dep[y]){
x = fa[x][j];
}
}
if(x == y) return x;
per(j, LOG, 0){
if(fa[x][j] != fa[y][j]){
x = fa[x][j];
y = fa[y][j];
}
}
return fa[x][0];
}
struct dsu{
int fa[N];
void init(int n) { rep(i, 1, n) fa[i] = i; }
int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }
void merge(int x, int y){
x = find(x), y = find(y);
if(x == y) return;
fa[x] = y;
}
} dsu;
void dfs2(int u, int FA){
for(int v : e[u]){
if(v == FA) continue;
dfs2(v, u);
if(tag[v] > 0){
dsu.merge(u, v);
}
tag[u] += tag[v];
}
}
void dfs3(int u, int FA){
f[u] = 0, g[u] = 1;
for(int v : e2[u]){
if(v == FA) continue;
dfs3(v, u);
f[u] = (f[u] + f[v]) % mod;
g[u] = g[u] * (g[v] + 1) % mod;
}
f[u] = (f[u] + g[u]) % mod;
}
void solve_test_case(){
fin >> n >> m;
LOG = log2(n) + 1;
rep(i, 1, n - 1){
int u, v;
fin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
gr[i] = {u, v};
}
dfs1(1, 1);
rep(i, 1, m){
int x, y;
fin >> x >> y;
int lca = LCA(x, y);
tag[x]++, tag[y]++, tag[lca] -= 2;
}
dsu.init(n);
dfs2(1, 1);
rep(i, 1, n - 1){
int u = gr[i].first, v = gr[i].second;
u = dsu.find(u), v = dsu.find(v);
if(u != v){
e2[u].push_back(v);
e2[v].push_back(u);
}
}
int rt = dsu.find(1);
dfs3(rt, rt);
fout << f[rt];
}
signed main(){
freopen("wolf.in", "r", stdin);
freopen("wolf.out", "w", stdout);
int Test_case_num = 1;
while(Test_case_num--) solve_test_case();
return 0;
}