Loading

2023 ICPC 亚洲区域赛济南站 B. Graph Partitioning 2

前言

讲还是要多听, 这个很重要啊

思路

赛时的思路不太正确啊

容易想到树形 \(\rm{dp}\) , 考虑令 \(f_{u, i}\) 表示对于 \(u\) 子树, 拆分出一块包含 \(u\) 的大小为 \(i\) 的连通块的方案数

考虑转移, 类似树上背包

\[f_{u, k} \gets \sum_{v \in son(u) , i + j = k} f_{v, i} \times f_{u, j} \]

逐个插入即可, 树上背包的上下界优化证明了这个的复杂度为 \(\mathcal{O} (nk)\)

对于 \(k\) 比较小的情况下, 直接使用这个算法是正确的
那么 \(k\) 比较大的呢?
考虑 \(k\) 比较大的时候怎么简化运算, 实际上此时可以被使用的状态是很少的, 具体的原因是, 假设你分了 \(x\) 个大小为 \(k\) 的块, 当然也可以计算出大小为 \(k + 1\) 的块的个数 $\displaystyle y = \left\lfloor \frac{size - xk}{k + 1} \right\rfloor $ , 所以当前覆盖 \(u\) 的块的大小为 \((size - xk) \textrm{ mod } (k + 1)\) , 容易发现的是 \(x\) 数量很少啊, 所以对应的覆盖 \(u\) 的块的大小的数量也就很少啊
需要注意的是, \((size - xk) \textrm{ mod } (k + 1) = 0\) 对应了两种情况, 必须新开一维确定这个的大小

考虑根号分治

  • 对于 \(k \leq \sqrt{n}\)
    直接使用树形背包
  • 对于 \(k \geq \sqrt{n}\)
    记录分出大小为 \(k\) 的块的个数, 转移即可

考虑 \(k \geq \sqrt{n}\) 时的具体转移
\(g_{u, i, 0/1}\) 表示对于 \(u\) 子树, 拆分出了 \(i\) 块大小为 \(k\) 的连通块, 其中 \(u\) 所在连通块大小是否为 \(k + 1\) 的方案数
假设当前考虑到了儿子 \(v\) , 还是逐个插入
首先写出简单形式的转移

\[g_{u, k} \gets \sum_{v \in son(u), i + j = k} g_{u, i} \times g_{v, j} \]

考虑 \(k + 1\) 的约束

首先对于 $ (size - xk) \textrm{ mod } (k + 1) \neq 0$ 和 $ (size - xk) \textrm{ mod } (k + 1) = 0$ 但是钦定了大小不为 \(k + 1\) 的情况正常处理

\[g_{u, k, 0} \gets \sum_{v \in son(u), i + j = k} g_{u, i, 0} \times g_{v, j, 0} \]

然后处理特殊情况 $ (size - xk) \textrm{ mod } (k + 1) = 0$

\[\begin{align*} & g_{u, k, 1} \gets g_{u, i, 1} \times g_{v, j, 1} \\ & g_{u, k, 1} \gets g_{u, i, 0} \times g_{v, j, 0} \\ & g_{u, k, 1} \gets g_{u, i, 1} \times g_{v, j, 0} \\ & g_{u, k, 0} \gets g_{u, i, 0} \times g_{v, j, 1} \end{align*} \]

意义从上到下依次为

  • 两个 \(k + 1\)\(v\)\(k + 1\) 留下 \(k + 1\)
  • 两个块拼起来恰好为 \(k + 1\)
  • \(v\) 的连通块大小为 \(0\), \(u\) 的连通块大小为 \(k + 1\)
  • \(v\) 的连通块大小为 \(k + 1\) , 直接删除

合并结束后, 类似的, 将所有连通块大小为 \(k\) 的状态更新一遍

\[f_{u, i + 1, 0} \gets f_{u, i, 0} \]

实现

框架

实现需要比较精细, 对我来说不好搞

其他的就是根据上面的模拟
注意树上背包的上下界优化不要写挂了
本质上就是双双合并的 \(\rm{trick}\)

代码

破防了, 反正也用不到, 不调了

放一份大佬代码

#include <bits/stdc++.h>
#define _rep(i, x, y) for(int i = x; i <= y; ++i)
#define _req(i, x, y) for(int i = x; i >= y; --i)
#define _rev(i, u) for(int i = head[u]; i; i = e[i].nxt)
#define pb push_back
#define fi first
#define se second
#define mst(f, i) memset(f, i, sizeof f)
using namespace std;
#ifdef ONLINE_JUDGE
#define debug(...) 0
#else
#define debug(...) fprintf(stderr, __VA_ARGS__), fflush(stderr)
#endif
typedef long long ll;
typedef pair<int, int> PII;
namespace fastio{
    #ifdef ONLINE_JUDGE
    char ibuf[1 << 20],*p1 = ibuf, *p2 = ibuf;
    #define get() p1 == p2 && (p2 = (p1 = ibuf) + fread(ibuf, 1, 1 << 20, stdin), p1 == p2) ? EOF : *p1++
    #else
    #define get() getchar()
    #endif
    template<typename T> inline void read(T &t){
        T x = 0, f = 1;
        char c = getchar();
        while(!isdigit(c)){
            if(c == '-') f = -f;
            c = getchar();
        }
        while(isdigit(c)) x = x * 10 + c - '0', c = getchar();
        t = x * f;
    }
    template<typename T, typename ... Args> inline void read(T &t, Args&... args){
        read(t);
        read(args...);
    }
    template<typename T> void write(T t){
        if(t < 0) putchar('-'), t = -t;
        if(t >= 10) write(t / 10);
        putchar(t % 10 + '0');
    }
    template<typename T, typename ... Args> void write(T t, Args... args){
        write(t), putchar(' '), write(args...);
    }
    template<typename T> void writeln(T t){
        write(t);
        puts("");
    }
    template<typename T> void writes(T t){
        write(t), putchar(' ');
    }
    #undef get
};
using namespace fastio;
#define multitest() int T; read(T); _rep(tCase, 1, T)
namespace Calculation{
    const ll mod = 998244353;
    ll ksm(ll p, ll h){ll base = p % mod, res = 1; while(h){if(h & 1ll) res = res * base % mod; base = base * base % mod, h >>= 1ll;} return res;}
    void dec(ll &x, ll y){x = ((x - y) % mod + mod) % mod;}
    void add(ll &x, ll y){x = (x + y) % mod;}
    void mul(ll &x, ll y){x = x * y % mod;}
    ll sub(ll x, ll y){return ((x - y) % mod + mod) % mod;}
    ll pls(ll x, ll y){return ((x + y) % mod + mod) % mod;}
    ll mult(ll x, ll y){return x * y % mod;}
}
using namespace Calculation;
const int N = 1e5 + 5, B = 300, M = 405;
int n, k, siz[N];
vector<int> G[N];
namespace solution1{
    ll f[N][M], g[M];
    void dfs(int u, int fa){
        f[u][1] = 1, siz[u] = 1;
        for(auto &v : G[u]){
            if(v == fa) continue;
            dfs(v, u);
            _rep(i, 1, min(siz[u], k + 1)) _rep(j, 0, min(siz[v], k + 1 - i)){
                add(g[i + j], f[u][i] * f[v][j] % mod);
            }
            siz[u] += siz[v];
            _rep(i, 0, min(siz[u], k + 1)) f[u][i] = g[i], g[i] = 0;
        }
        add(f[u][0], f[u][k] + f[u][k + 1]), f[u][k + 1] = 0;
    }
    void solve(){
        dfs(1, 0);
        writeln(f[1][0]);
        _rep(i, 1, n) _rep(j, 0, k + 1) f[i][j] = 0; 
        _rep(i, 1, n) G[i].clear(), siz[i] = 0;
    }
}
namespace solution2{
    ll f[N][M][2], g[M][2];
    void dfs(int u, int fa){
        f[u][0][0] = 1, siz[u] = 1;
        for(auto &v : G[u]){
            if(v == fa) continue;
            dfs(v, u);
            _rep(i, 0, siz[u] / k) _rep(j, 0, siz[v] / k){
                _rep(p, 0, 1) _rep(q, 0, 1){
                    int x = (siz[u] - i * k) / (k + 1) - p, y = (siz[v] - j * k) / (k + 1) - q;
                    x = siz[u] - (k + 1) * x - i * k, y = siz[v] - (k + 1) * y - j * k;
                    if(x){
                        if(!p && !q && x + y == k + 1){
                            add(g[i + j][1], f[u][i][p] * f[v][j][q]);
                        }
                        if(p){
                            if(q || !y) add(g[i + j][1], f[u][i][p] * f[v][j][q]);
                        }
                        if(!p){
                            if(q) add(g[i + j][0], f[u][i][p] * f[v][j][q]);
                            if(!q && x + y <= k){
                                add(g[i + j][0], f[u][i][p] * f[v][j][q]);
                            }
                        }
                    }
                }
            }
            siz[u] += siz[v];
            _rep(i, 0, siz[u] / k) _rep(p, 0, 1) f[u][i][p] = g[i][p], g[i][p] = 0;
        }
        _rep(i, 0, siz[u] / k) if((siz[u] - i * k) % (k + 1) == k) add(f[u][i + 1][0], f[u][i][0]); 
    }
    void solve(){
        dfs(1, 0);
        ll ans = 0;
        _rep(i, 0, n / k){
            _rep(j, 0, 1){
                if(j || (n - i * k) % (k + 1) == 0) add(ans, f[1][i][j]);
            }
        }
        writeln(ans);
        _rep(i, 1, n) _rep(j, 0, n / k) _rep(p, 0, 1) f[i][j][p] = 0; 
        _rep(i, 1, n) G[i].clear(), siz[i] = 0;
    }
}
int main(){
    multitest(){
        read(n, k);
        _rep(i, 2, n){
            int u, v; read(u, v);
            G[u].pb(v), G[v].pb(u);
        }
        if(k <= B) solution1::solve();
        else{
            solution2::solve();
        }
    }    
    return 0;
}

总结

树上背包的常见合并方式, 正确性不显然, 但是搞得很明白了
注意这一类背包要求每一颗子树都要合并进来, 特殊的实现方式
误区在如果这棵树在实际上没有合并, 本质上是 \(f_{v, 0}\) 的方案数, 实际上还是在合并

稍微总结一下树上背包:
合并之后再做整体操作, 合并之前可以根据意义初始化

一个暴力算法无法通过时, 考虑根号分治
分治的另一个算法当然可以使用更好的性质

\(\rm{dp}\) 这种问题, 尽量将状态限制成确定的

分类讨论注意不重不漏

posted @ 2025-01-08 15:21  Yorg  阅读(203)  评论(0)    收藏  举报