lca倍增法+求两点的最短长度

class Solution {
public:
    using ll = long long;  // 定义长整型别名,方便后续使用
    const int mod = 1e9+7; // 定义取模常量,避免重复书写
    
    // 快速幂函数,计算a的b次幂对mod取模的结果
    ll ksm(ll a, ll b) {
        ll res = 1;
        while (b) {
            if (b & 1) res = (res * a) % mod;  // 若当前位为1,则乘上当前的a
            a = (a * a) % mod;  // a自乘并取模
            b >>= 1;  // 右移一位,相当于除以2
        }
        return res;
    }
    
    // 主函数:为树的边分配权重,并处理查询
    vector<int> assignEdgeWeights(vector<vector<int>>& edges, vector<vector<int>>& queries) {
        int n = edges.size();  // 边的数量
        // 邻接表存储树的结构,fa数组存储每个节点的祖先信息
        vector<vector<int>> gh(n+2), fa(n+2, vector<int>(20, -1));
        vector<int> dep(n+2);  // 存储每个节点的深度
        
        // 构建邻接表
        for(auto& q : edges) {
            int u = q[0], v = q[1];
            gh[u].push_back(v);
            gh[v].push_back(u);
        }
        
        // 深度优先搜索函数,用于初始化节点深度和祖先数组
        auto dfs = [&](auto&& dfs, int u, int father) -> void {
            dep[u] = (father == -1) ? 0 : dep[father] + 1;  // 根节点深度为0,其余节点深度为父节点深度+1
            fa[u][0] = father;  // 直接父节点
            // 预处理倍增数组,fa[u][i]表示u的第2^i个祖先
            for(int i = 1; i < 20; i++) {
                if(fa[u][i-1] != -1) fa[u][i] = fa[fa[u][i-1]][i-1];
            }
            // 递归处理所有子节点
            for(auto& nb : gh[u]) {
                if(nb != father) dfs(dfs, nb, u); 
            }
        };
        
        dfs(dfs, 1, -1);  // 从节点1开始DFS,根节点的父节点设为-1
        
        vector<int> ans;  // 存储查询结果
        
        // 处理每个查询
        for(auto& q : queries) {
            int u = q[0], v = q[1];
            // 确保u的深度不小于v
            if(dep[u] < dep[v]) swap(u, v);
            
            // 提升u到v的同一深度
            for(int i = 19; i >= 0; i--) {
                if(fa[u][i] != -1 && dep[fa[u][i]] >= dep[v]) u = fa[u][i];
            }
            
            // 如果u和v相遇,说明LCA是u(或v)
            if(u == v) {
                int len = dep[q[0]] + dep[q[1]] - 2 * dep[u];  // 路径长度
                if(len == 0) ans.push_back(0);  // 特殊情况处理
                else ans.push_back(ksm(2, len-1));  // 计算结果并加入答案
                continue;
            }
            
            // 同时提升u和v,找到它们的LCA
            for(int i = 19; i >= 0; i--) {
                if(fa[u][i] != -1 && fa[v][i] != -1 && fa[u][i] != fa[v][i]) {
                    u = fa[u][i];
                    v = fa[v][i];
                }
            }
            
            // LCA是fa[u][0](或fa[v][0])
            int len = dep[q[0]] + dep[q[1]] - 2 * (dep[fa[u][0]]);
            ans.push_back(ksm(2, len-1));  // 计算结果并加入答案
        }
        
        return ans;  // 返回所有查询的结果
    }
};
posted @ 2025-05-28 18:20  Qacter  阅读(10)  评论(0)    收藏  举报