Loading

ZR 25 summer D7T1 题解 | 树上问题,dp

传送门

标签:树上问题,dp

题意

给你一个树和若干树上路径,到达一个路径中的点就一定要走完整条路径。问走过的点的集合的种类。

思路

发现可以缩点,将若干相互连接的路径缩成一个点。

赛时:缩点不知道怎么缩,先往后想。

发现缩完之后很简单,简单树形 dp,记录每个点 \(u\)\(f,g\),分别是一个点子树内的答案和连接到这个点方案数。

\[f_u=\sum\limits _{v \in son_u} f_v + \prod\limits _{v \in son_u} (g_v + 1) \]

\[g_u=\prod\limits _{v \in son_u} (g_v + 1) \]

缩点时,我们可以记录一个点是否被上面的点用路径连接。

具体地,\(tag_u\) 表示有无路径连接 \((u,fa_u)\)

这个可以每次读入路径 \((u,v)\) 时:

\[tag_u \leftarrow tag_u + 1 \]

\[tag_v \leftarrow tag_v + 1 \]

\[tag_{lca(u,v)} \leftarrow tag_{lca(u,v)} -2 \]

记得 dfs 时加一下。

然后没了,这个鬼做法赛时被卡常 70 pts,但是良心出题人赛后改时限,喜提 AC。

代码

link

注:finfout 是题目提供的快读。

const int N = 2e6 + 5;
const int MXLOG = 23;
const int mod = 998244353;
int n, m;
int tag[N];
ll f[N], g[N];
vector<int> e[N], e2[N];
pair<int, int> gr[N]; int E;
int fa[N][MXLOG], dep[N];
int LOG;
// int st[N << 1][MXLOG];
// int eu[N << 1], dfn[N], tot;
// int lg[N << 1];
void dfs1(int u, int FA){
    dep[u] = dep[FA] + 1;
    fa[u][0] = FA;
    rep(j, 1, LOG){
        fa[u][j] = fa[fa[u][j - 1]][j - 1];
    }
    for(int v : e[u]){
        if(v == FA) continue;
        dfs1(v, u);
        // eu[++tot] = u;
    }
}
int LCA(int x, int y){
    if(dep[x] < dep[y]) swap(x, y);
    per(j, LOG, 0){
        if(dep[fa[x][j]] >= dep[y]){
            x = fa[x][j];
        }
    }
    if(x == y) return x;
    per(j, LOG, 0){
        if(fa[x][j] != fa[y][j]){
            x = fa[x][j];
            y = fa[y][j];
        }
    }
    return fa[x][0];
}
struct dsu{
    int fa[N];
    void init(int n) { rep(i, 1, n) fa[i] = i; }
    int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }
    void merge(int x, int y){
        x = find(x), y = find(y);
        if(x == y) return;
        fa[x] = y;
    }
} dsu;
void dfs2(int u, int FA){
    for(int v : e[u]){
        if(v == FA) continue;
        dfs2(v, u);
        if(tag[v] > 0){
            dsu.merge(u, v);
        }
        tag[u] += tag[v];
    }
}
void dfs3(int u, int FA){
    f[u] = 0, g[u] = 1;
    for(int v : e2[u]){
        if(v == FA) continue;
        dfs3(v, u);
        f[u] = (f[u] + f[v]) % mod;
        g[u] = g[u] * (g[v] + 1) % mod;
    }
    f[u] = (f[u] + g[u]) % mod;
}
void solve_test_case(){
    fin >> n >> m;
    LOG = log2(n) + 1;
    rep(i, 1, n - 1){
        int u, v;
        fin >> u >> v;
        e[u].push_back(v);
        e[v].push_back(u);
        gr[i] = {u, v};
    }
    dfs1(1, 1);
    rep(i, 1, m){
        int x, y;
        fin >> x >> y;
        int lca = LCA(x, y);
        tag[x]++, tag[y]++, tag[lca] -= 2; 
    }
    dsu.init(n);
    dfs2(1, 1);
    rep(i, 1, n - 1){
        int u = gr[i].first, v = gr[i].second;
        u = dsu.find(u), v = dsu.find(v);
        if(u != v){
            e2[u].push_back(v);
            e2[v].push_back(u);
        }
    }
    int rt = dsu.find(1);
    dfs3(rt, rt);
    fout << f[rt];
}
signed main(){
    freopen("wolf.in", "r", stdin);
    freopen("wolf.out", "w", stdout);
    int Test_case_num = 1;
    while(Test_case_num--) solve_test_case();
    return 0;
}
posted @ 2025-08-12 15:58  lajishift  阅读(10)  评论(0)    收藏  举报