T1: 树的直径(二)

考虑由关键点构成的虚树,答案一定是 \(\lceil\frac{虚树直径}{2}\rceil\)

以任意点为根时,关键点中深度最大的点一定是虚树某条直径的一个端点,所以只需找到这个点,再遍历其余关键点,用 LCA 求出它们与该点的距离,取最大值即为虚树直径。

也可以跑两遍dfs求虚树直径,具体做法如下:

从任意一个点(不需要是关键点)开始第一遍 dfs,求出关键点中距离这个点最远的点 \(u\)(如果有多个,任取一个)。再从 \(u\) 开始进行第二遍 dfs,求出关键点中距离 \(u\) 最远的点 \(v\)(如果有多个,任取一个)。则 \((u, v)\) 即为虚树的一条直径。

注:这里没有必要真正建出虚树,直接在原树上求距离、只在关键点中取最远点即可。

代码实现1
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 0; i < (n); ++i)

using namespace std;

/// Binary-lifting lowest-common-ancestor structure for a tree whose edges
/// carry weights of type T. Vertices are 0-based.
///
/// Usage: construct with the vertex count, add the n-1 edges with addEdge(),
/// call init(root) once, then query
///   operator()(a, b) -> lowest common ancestor of a and b,
///   length(a, b)     -> number of edges on the a-b path,
///   dist(a, b)       -> total edge weight of the a-b path.
template<typename T>
struct lca {
    int n, l;                    // vertex count; number of lifting levels (smallest l with 2^l > n)
    vector<vector<int>> to;      // to[v] = neighbours of v
    vector<vector<T>> co;        // co[v][i] = weight of the edge (v, to[v][i])
    vector<int> dep;             // dep[v] = number of edges from the root
    vector<T> costs;             // costs[v] = summed edge weight from the root
    vector<vector<int>> par;     // par[v][i] = 2^i-th ancestor of v, or -1 if none
    lca(int n): n(n), to(n), co(n), dep(n), costs(n) {
        l = 0;
        while (1<<l <= n) ++l;   // any depth difference (< n) then fits in l bits
        par = vector<vector<int>>(n, vector<int>(l, -1));
    }
    // Adds the undirected edge a - b with weight c.
    void addEdge(int a, int b, T c=1) {
        to[a].push_back(b); co[a].push_back(c);
        to[b].push_back(a); co[b].push_back(c);
    }
    // Records parent p, depth d and root distance c for v, then does the same
    // for every vertex below v. An explicit stack replaces recursion so that a
    // deep (path-shaped) tree cannot overflow the call stack.
    void dfs(int v, int d=0, T c=0, int p=-1) {
        struct Frame { int v, p, d; T c; };
        vector<Frame> pending;
        pending.push_back({v, p, d, c});
        while (!pending.empty()) {
            Frame f = pending.back();
            pending.pop_back();
            par[f.v][0] = f.p;
            dep[f.v] = f.d;
            costs[f.v] = f.c;
            for (size_t i = 0; i < to[f.v].size(); ++i) {
                int u = to[f.v][i];
                if (u == f.p) continue;  // do not walk back to the parent
                pending.push_back({u, f.v, f.d+1, f.c+co[f.v][i]});
            }
        }
    }
    // Roots the tree at `root` and fills the lifting table level by level:
    // the 2^(i+1)-th ancestor is the 2^i-th ancestor of the 2^i-th ancestor.
    void init(int root=0) {
        dfs(root);
        for (int i = 0; i + 1 < l; ++i) {
            for (int v = 0; v < n; ++v) {
                par[v][i+1] = par[v][i]==-1 ? -1 : par[par[v][i]][i];
            }
        }
    }
    // LCA of a and b (requires init()).
    int operator()(int a, int b) {
        if (dep[a] > dep[b]) swap(a, b);
        // Lift b up to a's depth, decomposing the gap into powers of two.
        int gap = dep[b]-dep[a];
        for (int i = l-1; i >= 0; --i) {
            int len = 1<<i;
            if (gap >= len) {
                gap -= len;
                b = par[b][i];
            }
        }
        if (a == b) return a;
        // Lift both while their ancestors differ; they stop just below the LCA.
        for (int i = l-1; i >= 0; --i) {
            int na = par[a][i];
            int nb = par[b][i];
            if (na != nb) a = na, b = nb;
        }
        return par[a][0];
    }
    // Number of edges on the path a - b.
    int length(int a, int b) {
        int c = this->operator()(a, b);
        return dep[a]+dep[b]-dep[c]*2;
    }
    // Total edge weight of the path a - b.
    T dist(int a, int b) {
        int c = this->operator()(a, b);
        return costs[a]+costs[b]-costs[c]*2;
    }
};

void solve() {
    int n, k;
    cin >> n >> k;
    
    vector<int> vs(k);
    rep(i, k) cin >> vs[i], vs[i]--;
    
    lca<int> g(n);
    rep(i, n-1) {
        int u, v;
        cin >> u >> v;
        --u; --v;
        g.addEdge(u, v);
    }
    g.init();
    
    int maxd = -1, a = -1;
    for (int v : vs) {
        if (g.dep[v] > maxd) {
            maxd = g.dep[v];
            a = v;
        } 
    }
    
    int ans = 0;
    for (int v : vs) {
        ans = max(ans, g.dist(a, v));
    }
    ans = (ans+1)/2;
    cout << ans << '\n';
}

// Entry point: reads the number of test cases and solves each one in turn.
int main() {
    int cases;
    cin >> cases;
    for (; cases != 0; --cases) {
        solve();
    }
    return 0;
}
代码实现2
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 0; i < (n); ++i)

using namespace std;
using P = pair<int, int>;

// Handles one test case: reads n, k, the k key vertices and the n-1 tree
// edges, then prints ceil(D / 2), where D is the largest distance between two
// key vertices (0 when there are no key vertices).
//
// Double sweep: the key vertex farthest from any start vertex is an endpoint
// of a diameter of the key set, and the key vertex farthest from that endpoint
// realises D. Distances come from BFS, so a deep tree cannot overflow the
// call stack.
void solve() {
    int n, k;
    cin >> n >> k;

    vector<bool> selected(n);
    for (int i = 0; i < k; ++i) {
        int a;
        cin >> a;
        selected[a - 1] = true;
    }

    vector<vector<int>> to(n);
    for (int i = 0; i + 1 < n; ++i) {
        int u, v;
        cin >> u >> v;
        --u; --v;
        to[u].push_back(v);
        to[v].push_back(u);
    }

    // Returns {distance, vertex} for the farthest selected vertex from src
    // (larger id on ties), or {0, -1} if no vertex is selected.
    auto farthest = [&](int src) {
        pair<int, int> best(0, -1);
        vector<int> dist(n, -1);
        vector<int> order{src};
        dist[src] = 0;
        for (size_t head = 0; head < order.size(); ++head) {
            int v = order[head];
            if (selected[v]) best = max(best, make_pair(dist[v], v));
            for (int u : to[v]) {
                if (dist[u] != -1) continue;
                dist[u] = dist[v] + 1;
                order.push_back(u);
            }
        }
        return best;
    };

    int diameter = 0;
    if (n > 0) {
        int a = farthest(0).second;
        // a == -1 means there are no key vertices: nothing to cover.
        if (a != -1) diameter = farthest(a).first;
    }
    cout << (diameter + 1) / 2 << '\n';
}

// Entry point: reads the number of test cases and solves each one in turn.
int main() {
    int cases;
    cin >> cases;
    for (; cases != 0; --cases) {
        solve();
    }
    return 0;
}

T2:简单 MST

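记 \(w(x)\) 为 \(x\) 的不同质因子个数,边 \((a, b)\) 的权值为 \(w(a) + w(b) - w(\gcd(a, b))\),即 \(\operatorname{lcm}(a, b)\) 的不同质因子个数。对于 \(g = 1, 2, \cdots, r\),遍历 \([l, r]\) 中 \(g\) 的所有倍数,找到其中 \(w\) 最小的那个,将它和其余倍数连边;最后对这些边求 MST 即可。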

代码实现
#include <bits/stdc++.h>

using namespace std;
using ll = long long;

// Disjoint-set union with union by size and path compression.
// d[x] < 0 means x is a root whose set holds -d[x] elements;
// otherwise d[x] is x's parent.
struct UnionFind {
    vector<int> d;

    UnionFind(int n = 0) : d(n, -1) {}

    // Root of x's set; every node on the way is re-pointed straight at it.
    int find(int x) {
        return d[x] < 0 ? x : (d[x] = find(d[x]));
    }

    // Merges the sets of x and y; returns false if they were already joined.
    bool unite(int x, int y) {
        int rx = find(x);
        int ry = find(y);
        if (rx == ry) return false;
        // Keep the larger set's root on top (sizes are stored negated).
        if (-d[rx] < -d[ry]) swap(rx, ry);
        d[rx] += d[ry];
        d[ry] = rx;
        return true;
    }

    // True when x and y belong to the same set.
    bool same(int x, int y) {
        return find(y) == find(x);
    }

    // Number of elements in x's set.
    int size(int x) {
        const int root = find(x);
        return -d[root];
    }
};

// Handles one test case: reads l and r and prints the weight of a minimum
// spanning tree on the vertices l..r. The weight of the edge a - b is
// w(a) + w(b) - w(gcd(a, b)), i.e. the number of distinct primes of
// lcm(a, b), where w(x) is the number of distinct primes dividing x.
//
// For each g, every multiple of g in [l, r] is joined to the multiple with the
// smallest w. Kruskal then runs over these candidate edges, bucketed by weight.
void solve() {
    int l, r;
    cin >> l >> r;

    // Sieve: w[x] = number of distinct primes dividing x.
    vector<int> w(r+1);
    for (int i = 2; i <= r; ++i) {
        if (w[i]) continue;  // w[i] == 0 here means no smaller prime divides i: i is prime
        for (int j = i; j <= r; j += i) {
            w[j]++;
        }
    }

    // An edge weight never exceeds w(a) + w(b) <= 2 * max(w), so the bucket
    // count is derived from the sieve rather than from a fixed constant.
    const int maxw = *max_element(w.begin(), w.end());
    vector<vector<pair<int, int>>> es(2*maxw + 1);
    for (int g = 1; g <= r; ++g) {
        int first = (l+g-1)/g*g;  // smallest multiple of g that is >= l
        int hub = first;
        for (int i = first; i <= r; i += g) {
            if (w[hub] > w[i]) {
                hub = i;
            }
        }
        for (int b = first; b <= r; b += g) {
            // distinct primes of lcm = primes of hub + primes of b - shared primes
            int c = w[hub]+w[b]-w[gcd(hub, b)];
            es[c].emplace_back(hub, b);
        }
    }

    // Kruskal over the weight buckets in increasing order (weight-0 edges can
    // only be the self-loop 1 - 1, which unite() rejects).
    ll ans = 0;
    UnionFind uf(r+1);
    for (int c = 0; c < static_cast<int>(es.size()); ++c) {
        for (auto [a, b] : es[c]) {
            if (uf.unite(a, b)) ans += c;
        }
    }

    cout << ans << '\n';
}

// Entry point: reads the number of test cases and solves each one in turn.
int main() {
    int cases;
    cin >> cases;
    for (; cases != 0; --cases) {
        solve();
    }
    return 0;
}

T3: 序列切割

设 \(p(l, r, i)\) 表示操作 \(i\) 次后恰好剩下 \(a[l..r]\) 的概率,则答案为 \(\sum\limits_{1 \leqslant l \leqslant r \leqslant n} p(l, r, k) \times (a_l + a_{l+1} + \cdots + a_r)\)。

初值为 \(p(1, n, 0) = 1\),其余的 \(p(l, r, 0) = 0\)。

对于 \(i \geqslant 1\),考虑转移:

当 \(S_i = \mathtt{L}\) 时(保留左半部分):\(p(l, r, i) = \sum\limits_{x=r+1}^n p(l, x, i-1) \cdot \frac{1}{x-l}\)

当 \(l = r\) 时,还要额外加上 \(p(l, l, i-1)\),表示单元素序列无法再切割。

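当 \(S_i = \mathtt{R}\) 时(保留右半部分),转移是对称的:\(p(l, r, i) = \sum\limits_{x=1}^{l-1} p(x, r, i-1) \cdot \frac{1}{r-x}\),同样在 \(l = r\) 时额外加上 \(p(r, r, i-1)\)。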

此时状态数为 \(\mathcal{O}(n^2k)\),每个状态的转移为 \(\mathcal{O}(n)\),总复杂度为 \(\mathcal{O}(n^3k)\)。

用前缀和(差分)优化转移即可,总复杂度降为 \(\mathcal{O}(n^2k)\)。

代码实现
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 0; i < (n); ++i)

using namespace std;
using ll = long long;

//const int mod = 998244353;
const int mod = 1000000007;

/// Residue modulo the prime `mod`. Every constructor and operator keeps the
/// stored value x in [0, mod); += and -= rely on that invariant to reduce
/// with a single subtraction.
struct mint {
    ll x;
    // Accepts any ll, including negative values, and reduces it into [0, mod).
    mint(ll x=0):x((x%mod+mod)%mod) {}
    mint operator-() const {
        return mint(-x);
    }
    mint& operator+=(const mint a) {
        if ((x += a.x) >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint a) {
        if ((x += mod-a.x) >= mod) x -= mod;
        return *this;
    }
    mint& operator*=(const mint a) {
        (x *= a.x) %= mod;  // both factors are < mod, so the product fits in ll
        return *this;
    }
    mint operator+(const mint a) const {
        return mint(*this) += a;
    }
    mint operator-(const mint a) const {
        return mint(*this) -= a;
    }
    mint operator*(const mint a) const {
        return mint(*this) *= a;
    }
    // this^t by recursive binary exponentiation; expects t >= 0 (0^0 is 1).
    mint pow(ll t) const {
        if (!t) return 1;
        mint a = pow(t>>1);
        a *= a;
        if (t&1) a *= *this;
        return a;
    }

    // for prime mod: inverse via Fermat's little theorem, this^(mod-2).
    // NOTE: the "inverse" of 0 comes out as 0 instead of signalling an error.
    mint inv() const {
        return pow(mod-2);
    }
    mint& operator/=(const mint a) {
        return *this *= a.inv();
    }
    mint operator/(const mint a) const {
        return mint(*this) /= a;
    }
};
// Reads an integer and reduces it into [0, mod). Storing the raw value would
// break the invariant above for negative or oversized input.
istream& operator>>(istream& is, mint& a) {
    ll v = 0;
    is >> v;
    a = mint(v);
    return is;
}
// Writes the canonical representative in [0, mod).
ostream& operator<<(ostream& os, const mint& a) {
    return os << a.x;
}

// Global table of modular inverses. main() fills inv[0..499] with
// mint(i).inv(), so inv[0] ends up as 0 (0^(mod-2) is 0); inv[500..504] keep
// the default value 0.
// NOTE(review): a lookup inv[d] with d >= 500 yields 0 or reads out of bounds,
// so callers must keep d below 500.
mint inv[505];

void solve() {
    int n, k;
    cin >> n >> k;
    
    vector<int> a(n);
    rep(i, n) cin >> a[i];
    
    string s;
    cin >> s;
    
    vector dp(n, vector<mint>(n+1));
    dp[0][n-1] = 1;
    rep(i, k) {
        vector old(n, vector<mint>(n+1));
        swap(dp, old);
        
        rep(l, n)for (int r = l; r < n; ++r) {
            if (l == r) {
                dp[l][r] += old[l][r];
                if (s[i] == 'L') dp[l][r+1] -= old[l][r];
                else if (l > 0) dp[l-1][l] -= old[l][r];
            }
            else {
                mint val = old[l][r] * inv[r-l];
                if (s[i] == 'L') dp[l][l] += val;
                else dp[r][r] += val;
                dp[l][r] -= val;
            }
        }
        
        if (s[i] == 'L') {
            rep(l, n) {
                for (int r = 1; r < n; ++r) {
                    dp[l][r] += dp[l][r-1];
                }
            }
        }
        else {
            rep(r, n) {
                for (int l = n-2; l >= 0; --l) {
                    dp[l][r] += dp[l+1][r];
                }
            }
        }
    }
    
    mint ans;
    rep(l, n) {
        mint sum;
        for (int r = l; r < n; ++r) {
            sum += a[r];
            ans += sum*dp[l][r];
        }
    }
    cout << ans << '\n';
}

int main() {
    rep(i, 500) inv[i] = mint(i).inv(); 
    
    int t;
    cin >> t;
    
    while (t--) solve();
    
    return 0;
}