树型dp2

[Algo] 树型dp2

1. 到达首都的最少油耗

// 1. Minimum Fuel Cost to Report to the Capital
// https://leetcode.cn/problems/minimum-fuel-cost-to-report-to-the-capital/description/

// Post-order DFS over the subtree rooted at `cur`, entered from `parent`
// (-1 for the root). On return:
//   size[cur] has been increased by the number of nodes in that subtree;
//   cost[cur] has been increased by the fuel needed to bring every node of
//             that subtree to `cur`.
// Moving a child's whole subtree across the edge child->cur needs
// ceil(size[child] / seats) cars, each burning one unit on that edge.
void func1(vector<vector<int>>& graph, int seats, vector<long long>& cost, vector<int>& size, int cur, int parent) {
    int people = 1;  // the node `cur` itself
    long long fuel = 0;
    for (int next : graph[cur]) {
        if (next == parent) {
            continue;  // never walk back up the tree
        }
        func1(graph, seats, cost, size, next, cur);
        people += size[next];
        fuel += cost[next] + (size[next] + seats - 1) / seats;
    }
    size[cur] += people;
    cost[cur] += fuel;
}

// Returns the minimum total fuel for every city to reach city 0.
// `roads` holds the n - 1 undirected edges of a tree over cities 0..n-1.
long long minimumFuelCost(vector<vector<int>>& roads, int seats) {
    const int cities = static_cast<int>(roads.size()) + 1;
    vector<vector<int>> adjacency(cities);
    for (const vector<int>& road : roads) {
        adjacency[road[0]].push_back(road[1]);
        adjacency[road[1]].push_back(road[0]);
    }
    vector<long long> fuel(cities, 0);
    vector<int> people(cities, 0);
    func1(adjacency, seats, fuel, people, 0, -1);
    return fuel.front();
}

2. 相邻字符不同的最长路径

// 2. Longest Path With Different Adjacent Characters
// https://leetcode.cn/problems/longest-path-with-different-adjacent-characters/description/

// Summary of a subtree, returned by func2.
struct Info {
    int maxPath;          // longest valid path (in nodes) anywhere inside the subtree
    int maxPathFromHead;  // longest valid downward path starting at the subtree root
    Info(int path, int fromHead) : maxPath(path), maxPathFromHead(fromHead) {}
};

// Post-order DFS over the rooted tree `graph` (child lists). A path is valid
// when no two adjacent nodes on it carry the same character in `s`.
Info func2(vector<vector<int>>& graph, string& s, int cur) {
    int best = 0;    // best path found inside some child subtree
    int top = 0;     // longest downward chain via a child whose char differs
    int second = 0;  // runner-up chain of the same kind
    for (int child : graph[cur]) {
        const Info sub = func2(graph, s, child);
        best = max(best, sub.maxPath);
        if (s[child] == s[cur]) {
            continue;  // a path cannot join two equal adjacent characters
        }
        // Keep the two largest chains: insert into `second`, then bubble up.
        if (sub.maxPathFromHead > second) {
            second = sub.maxPathFromHead;
            if (second > top) {
                swap(top, second);
            }
        }
    }
    // A leaf falls through with best = top = second = 0 and yields Info(1, 1).
    const int throughCur = top + second + 1;  // two chains joined at `cur`
    return Info(max(best, throughCur), top + 1);
}

// Returns the number of nodes on the longest valid path.
// Node 0 is the root; parent[0] is never read.
int longestPath(vector<int>& parent, string s) {
    vector<vector<int>> children(parent.size());
    for (size_t v = 1; v < parent.size(); ++v) {
        children[parent[v]].push_back(static_cast<int>(v));
    }
    return func2(children, s, 0).maxPath;
}

3. 移除子树后的二叉树高度

// 3. Height of Binary Tree After Subtree Removal Queries
// https://leetcode.cn/problems/height-of-binary-tree-after-subtree-removal-queries/description/
//
// Approach: number nodes in DFS preorder. A node's subtree then occupies the
// contiguous id range [dfn, dfn + subtreeSize). Removing it leaves the prefix
// [1, dfn) and the suffix [dfn + subtreeSize, cnt]. The new height is the
// larger of the prefix and suffix maximum depths.
//
// dfn[] is indexed by node value. subtreeSize[], depth[], maxl[] and maxr[]
// are indexed by preorder id (1-based). The fixed sizes assume node values
// and node count are at most 100000.
// The array is named subtreeSize, not `size`: a global `size` is ambiguous
// with std::size (C++17) under `using namespace std`.
int dfn[100001], subtreeSize[100001], depth[100001], cnt = 0;
int maxl[100000], maxr[100002];

// Preorder DFS: assigns preorder ids and records each node's depth
// (root = k) and subtree size.
void func3(TreeNode* root, int k) {
    const int id = ++cnt;
    dfn[root->val] = id;
    depth[id] = k;
    subtreeSize[id] = 1;
    if (root->left) {
        func3(root->left, k + 1);
        subtreeSize[id] += subtreeSize[dfn[root->left->val]];
    }
    if (root->right) {
        func3(root->right, k + 1);
        subtreeSize[id] += subtreeSize[dfn[root->right->val]];
    }
}

// For each query value, returns the tree height after removing that node's
// subtree. Each query is independent.
vector<int> treeQueries(TreeNode* root, vector<int>& queries) {
    vector<int> answer;
    if (root == nullptr) return answer;
    // The globals persist between calls (a judge may call this repeatedly in
    // one process), so restart the preorder counter and reset the sentinels.
    cnt = 0;
    func3(root, 0);
    maxl[0] = 0;        // max depth over an empty prefix
    maxr[cnt + 1] = 0;  // max depth over an empty suffix
    for (int i = 1; i < cnt; i++) maxl[i] = max(maxl[i - 1], depth[i]);
    for (int i = cnt; i > 1; i--) maxr[i] = max(maxr[i + 1], depth[i]);
    answer.reserve(queries.size());
    for (int query : queries) {
        const int id = dfn[query];
        answer.push_back(max(maxl[id - 1], maxr[id + subtreeSize[id]]));
    }
    return answer;
}

4. 从树中删除边的最小分数

// 4. Minimum Score After Removals on a Tree
// https://leetcode.cn/problems/minimum-score-after-removals-on-a-tree/
//
// Approach: number nodes in DFS preorder and record, per preorder id, the
// subtree size and the XOR of the subtree. For a tree edge, the endpoint with
// the larger preorder id is the child, and cutting the edge detaches exactly
// that child's subtree. For two cut edges with child ids pre < post, post's
// subtree is either nested in pre's (post < pre + subtreeSize[pre]) or
// disjoint from it. That case decides the XOR of the three components.
//
// DFS state shared with func4:
//   dfn[] is indexed by node id, with 0 meaning not visited.
//   subtreeSize[] and xorSum[] are indexed by preorder id (1-based).
// minimumScore resizes and clears them on every call, so there is no fixed
// limit on n and repeated calls never see stale data.
// The array is named subtreeSize, not `size`: a global `size` is ambiguous
// with std::size under `using namespace std`.
vector<int> dfn, subtreeSize, xorSum;
int cnt = 0;

// Preorder DFS from `cur`; fills dfn, subtreeSize and xorSum for its subtree.
void func4(vector<int>& nums, vector<vector<int>>& graph, int cur) {
    const int id = ++cnt;
    dfn[cur] = id;
    subtreeSize[id] = 1;
    xorSum[id] = nums[cur];
    for (int adj : graph[cur]) {
        if (dfn[adj] == 0) {  // unvisited, hence a child of `cur`
            func4(nums, graph, adj);
            subtreeSize[id] += subtreeSize[dfn[adj]];
            xorSum[id] ^= xorSum[dfn[adj]];
        }
    }
}

// Returns the minimum, over all pairs of distinct removed edges, of
// (largest component XOR) - (smallest component XOR).
// Expects a tree on nodes 0..n-1. With fewer than two edges there is no
// pair to remove, and INT32_MAX is returned.
int minimumScore(vector<int>& nums, vector<vector<int>>& edges) {
    const int m = static_cast<int>(edges.size());
    if (m < 2) return INT32_MAX;

    const int n = static_cast<int>(nums.size());
    vector<vector<int>> graph(n);
    for (const vector<int>& e : edges) {
        graph[e[0]].push_back(e[1]);
        graph[e[1]].push_back(e[0]);
    }

    // Reset all DFS state; dfn doubles as the visited flag.
    cnt = 0;
    dfn.assign(n, 0);
    subtreeSize.assign(n + 1, 0);
    xorSum.assign(n + 1, 0);
    func4(nums, graph, 0);

    // Represent each edge by the preorder id of its child endpoint.
    vector<int> cut(m);
    for (int i = 0; i < m; i++) {
        cut[i] = max(dfn[edges[i][0]], dfn[edges[i][1]]);
    }

    const int total = xorSum[1];  // XOR of the whole tree (root has id 1)
    int ans = INT32_MAX;
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < i; j++) {
            const int pre = min(cut[i], cut[j]);
            const int post = max(cut[i], cut[j]);
            const int sum3 = xorSum[post];
            int sum1, sum2;
            if (post < pre + subtreeSize[pre]) {
                // post's subtree is nested inside pre's subtree
                sum2 = xorSum[pre] ^ sum3;   // pre's subtree minus post's
                sum1 = total ^ xorSum[pre];  // everything outside pre's subtree
            } else {
                // the two detached subtrees are disjoint
                sum2 = xorSum[pre];
                sum1 = total ^ xorSum[pre] ^ sum3;
            }
            ans = min(ans, max({sum1, sum2, sum3}) - min({sum1, sum2, sum3}));
        }
    }
    return ans;
}
posted @ 2025-03-22 17:02  yaoguyuan  阅读(8)  评论(0)    收藏  举报