Hetao P10588 十载峥嵘桀骜 题解 [ 紫 ] [ 树的直径 ] [ 矩阵加速 DP ] [ 状态设计优化 ]

十载峥嵘桀骜:感觉挺简单的,就是代码处理比较繁琐。

一个最简单的部分分是暴力模拟建图之后跑矩阵快速幂转移,时间复杂度 \(O(n^3\log t)\),随便拼点其他特殊性质就能 68pts 了。

考虑正解,结合树的直径的 DFS 求法,容易注意到在走完第一天后,第二天开始走的就是树的直径了

同时又有“若树上所有边边权均为正,则树的所有直径中点重合”的一个经典结论,因此每天走的路线就是形如:\(u \to root \to v\)。其中 \(u\) 指一个在直径上的叶子,\(root\) 指直径中点,\(v\) 指另一个在直径上的叶子。为了保证每次走的是一个直径,需要保证在以 \(root\) 为根时,\(\bm{u, v}\) 不在一个子树内

由此性质可以对原来的暴力进行一点优化,把以 \(root\) 为根的每个儿子当成矩阵上的一个节点,每次转移相当于从一个儿子转移到另一个儿子。令 \(dp_i\) 表示在节点 \(i\) 方案数(要求 \(i\)\(root\) 的儿子,下文同理),\(x_i\) 表示节点 \(i\) 的子树中在直径上的叶子个数,则从节点 \(a\) 转移到节点 \(b\) 的方程为 \(dp_b\overset{+}{\leftarrow}dp_a\times x_b\)。时间复杂度实质没有改变,菊花就能把它卡飞。

继续观察性质,注意到转移方程只与 \(\bm x\) 的值有关,而 \(\sum x\le n\),所以可以从 \(x\) 的种类数考虑,发现不同的 \(\bm x\) 最多只有 \(\bm{\sqrt n}\)。因此可以将 \(x\) 相同的 \(i\) 进行合并,同时特殊处理 \(x\) 相同的节点之间的转移。此时矩阵上的节点被减少到 \(\sqrt n\) 量级,跑矩阵快速幂即可。时间复杂度 \(O(n\sqrt n\log t)\)

因为直径的中点可能在某条边上,不便于处理,于是可以在每条边上都插入一个虚点,这样就能避免中点在边上的情况。

#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi = pair<int, int>;
const int N = 5005, M = 105;
const ll mod = 1e9 + 7;
ll n, kcnt, troot, ksz[N], knum[N], m;
vector<int> g[N];
struct Matrix{
    ll a[M][M];
    Matrix(){ memset(a, 0, sizeof(a)); }
    Matrix operator * (const Matrix & t) const{
        Matrix res;
        for(int i = 1; i <= kcnt; i++)
            for(int k = 1; k <= kcnt; k++)
                for(int j = 1; j <= kcnt; j++)
                    res.a[i][j] = (res.a[i][j] + a[i][k] * t.a[k][j]) % mod;
        return res;
    }
};
Matrix mqpow(Matrix a, ll b)
{
    Matrix res;
    for(int i = 0; i <= kcnt; i++) res.a[i][i] = 1;
    while(b)
    {
        if(b & 1) res = res * a;
        b >>= 1;
        a = a * a;
    }
    return res;
}
ll dep[N][N], diam, depmx[N], depcnt[N], scnt[N], stot[N], tot[N];
void dfs(int u, int fa, int anc)
{
    dep[anc][u] = dep[anc][fa] + 1;
    for(auto v : g[u])
    {
        if(v == fa) continue;
        dfs(v, u, anc);
    }
}
void dfs2(int u, int fa, int anc)
{
    depmx[u] = dep[anc][u];
    depcnt[u] = 1;
    for(auto v : g[u])
    {
        if(v == fa) continue;
        dfs2(v, u, anc);
        scnt[u] = (scnt[u] + scnt[v]) % mod;
        if(depmx[v] > depmx[u])
        {
            depmx[u] = depmx[v];
            depcnt[u] = depcnt[v];
        }
        else if(depmx[v] == depmx[u]) depcnt[u] += depcnt[v];
    }
}
void solve()
{
    cin >> n >> m;
    memset(dep, 0, sizeof(dep));
    for(int i = 0; i <= 2 * n - 1; i++)
    {
        g[i].clear();
        tot[i] = ksz[i] = knum[i] = depmx[i] = depcnt[i] = scnt[i] = stot[i] = 0;
    }
    for(int i = 1; i < n; i++)
    {
        int u, v;
        cin >> u >> v;
        g[u].push_back(n + i);
        g[n + i].push_back(u);
        g[v].push_back(n + i);
        g[n + i].push_back(v);
    }
    if(n == 1)
    {
        cout << 1 << "\n";
        return;
    }
    diam = 0;
    for(int i = 1; i <= 2 * n - 1; i++)
    {
        dfs(i, 0, i);
        ll nmx = 0;
        for(int j = 1; j <= 2 * n - 1; j++)
        {
            diam = max(diam, dep[i][j]);
            nmx = max(nmx, dep[i][j]);
        }
        if(i > n) continue;
        for(int j = 1; j <= 2 * n - 1; j++)
        {
            if(j > n) continue;
            if(nmx == dep[i][j])
                scnt[j]++;
        }        
    }
    troot = 0;
    for(int i = 1; i <= 2 * n - 1; i++)
    {
        for(int j = 1; j <= 2 * n - 1; j++)
        {
            if(dep[i][j] == diam)
            {
                for(int k = 1; k <= 2 * n - 1; k++)
                {
                    if(dep[k][i] + dep[k][j] - 1 == diam && dep[k][i] == dep[k][j])
                    {
                        troot = k;
                        break;
                    }
                }
                break;
            }
        }
        if(troot) break;
    }    
    dfs2(troot, 0, troot);
    kcnt = 0;
    for(auto v : g[troot])
    {
        if(depmx[v] != depmx[troot]) continue;
        tot[depcnt[v]]++;
        stot[depcnt[v]] = (stot[depcnt[v]] + scnt[v]) % mod;
    }
    Matrix dp, sx;
    for(int i = 0; i <= 2 * n - 1; i++)
    {
        if(tot[i])
        {
            ksz[++kcnt] = tot[i];
            knum[kcnt] = i;
            sx.a[kcnt][kcnt] = stot[i];
        }
    }
    for(int i = 1; i <= kcnt; i++)
    {
        for(int j = 1; j <= kcnt; j++)
        {
            if(i == j)
                dp.a[i][j] = knum[i] * (ksz[i] - 1) % mod;
            else
                dp.a[i][j] = knum[j] * ksz[j] % mod;
        }
    }
    sx = sx * mqpow(dp, m - 1);
    ll ans = 0;
    for(int i = 1; i <= kcnt; i++)
        for(int j = 1; j <= kcnt; j++)
            ans = (ans + sx.a[i][j]) % mod;
    cout << ans << "\n";
}
int main()
{
    //freopen("sample.in", "r", stdin);
    //freopen("sample.out", "w", stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int tid, t;
    cin >> tid >> t;
    while(t--) solve();
    return 0;
}
posted @ 2025-09-23 18:48  KS_Fszha  阅读(10)  评论(0)    收藏  举报