E. Lomsat gelral

https://codeforces.com/contest/600/problem/E

题意:给一颗树,如果当前叶子为根的树中数字出现最多次数为k,求该树中所有出现次数为k的数字之和。

思路:dfs + 线段树合并。

总结:第一次接触线段树合并,整理了3个上午才整理出模板来,不知道这种线段树合并有没有区间更新的功能,这个题目是单点更新,所以板子就先写成单点的了。

/*
  线段树的节点
  重载了合并区间符号 '+'
  实现了区间上更新数值的模板函数
*/
class SegmentTreeNode {
public:
    SegmentTreeNode() {}

    /*
      重载左右孩子区间合并操作,要根据题目自己写
    */
    friend SegmentTreeNode operator + (const SegmentTreeNode& a, const SegmentTreeNode& b) {
        SegmentTreeNode res{ a };
        if (a.cnt == b.cnt) {
            res.sum = a.sum + b.sum;
        }
        else if (a.cnt < b.cnt) {
            res.cnt = b.cnt;
            res.sum = b.sum;
        }
        return res;
    }

    /*
      重载线段树叶子节点合并,要根据题意自己写
    */
    friend SegmentTreeNode operator | (const SegmentTreeNode& a, const SegmentTreeNode& b){
        SegmentTreeNode res = a;
        res.cnt += b.cnt;
        res.sum = max(res.sum, b.sum);
        return res;
    }


    /*
      区间更新值value操作,在当前的区间节点上更新数值value到该区间
    */
    template<typename T>
    void applyUpdate(T value) {
        cnt += 1;
        sum = value;
    }
    int cnt = 0;
    long long sum = 0;

};




/*
  多重线段树,用于线段树合并类操作
  该类实例化需要传参,参数是要更新的区间类型,可以为基础数据类型或者自己实现的类型,该类型与节点中的模板函数参数类型对应
  该方法完全动态分配空间(类似字典树内存管理),在时间与性能上都有着不俗的优势
*/
template<typename T>
class MultiSegmentTree {
public:
    MultiSegmentTree(int sz) : sz_(sz) {
        st_.resize(1);
        lchild_.resize(1);
        rchild_.resize(1);
        root_.resize(sz_);
    }

    /*
      更新线tree_id线段树
    */
    void update(int tree_id, int pos, T value) {
        checkNodeIndex(root_[tree_id]);
        update(root_[tree_id], 1, sz_, pos, value);
    }

    /*
      合并两棵线段树
    */
    void merge(int u, int v) {
        assert(u && v && std::max(u, v) < sz_);
        root_[u] = merge(root_[u], root_[v], 1, sz_);
    }

    /*
      需要根据问题重新手写,获取当前线段树的结果
    */
    long long get(int u) {
        assert(u < sz_);
        return st_[root_[u]].sum;
    }


private:
    int sz_;
    std::vector<SegmentTreeNode> st_;
    std::vector<int> lchild_;
    std::vector<int> rchild_;
    std::vector<int> root_;

    /*
      线段树单点更新,位置在pos,数值是value
    */
    void update(int p, int l, int r, int pos, T value) {
        if (l == r) {
            st_[p].applyUpdate(pos);
            return;
        }
        int mid = (l + r) >> 1;
        if (pos <= mid) {
            checkNodeIndex(lchild_[p]);
            update(lchild_[p], l, mid, pos, value);
        }
        else {
            checkNodeIndex(rchild_[p]);
            update(rchild_[p], mid + 1, r, pos, value);
        }

        st_[p] = st_[lchild_[p]] + st_[rchild_[p]];
    }

    /*
      合并两颗线段树
    */
    int merge(int p, int q, int l, int r) {
        if (!p || !q) {
            return p + q;
        }
        if (l == r) {
            st_[p] = st_[p] | st_[q];
            return p;
        }
        int mid = (l + r) >> 1;
        lchild_[p] = merge(lchild_[p], lchild_[q], l, mid);
        rchild_[p] = merge(rchild_[p], rchild_[q], mid + 1, r);
        st_[p] = st_[lchild_[p]] + st_[rchild_[p]];
        return p;
    }

    /*
      使用该函数完全动态分配空间,一般空间比直接分配内存节省百分之50
    */
    inline void checkNodeIndex(int& index) {
        if (index == 0) {
            index = st_.size();
            st_.emplace_back();
            lchild_.emplace_back(0);
            rchild_.emplace_back(0);
        }
    }
};



void preProcess(){

}








void solve(){
    int n;
    cin >> n;

    vector<int> a(n + 1);
    for (int i = 1; i <= n; ++i){
        cin >> a[i];
    }

    vector<vector<int>>al(n + 1);
    for (int i = 0; i < n - 1; ++i){
        int u, v;
        cin >> u >> v;
        al[u].push_back(v);
        al[v].push_back(u);
    }

    MultiSegmentTree<int> st(n + 1);
    vector<long long> ans(n + 1);

    function<void(int, int)> dfs = [&](int u, int p){
        for (const auto& v : al[u]){
            if (v != p){
                dfs(v, u);
                st.merge(u, v);
            }
        }
        st.update(u, a[u], 1);
        ans[u] = st.get(u);
    };

    dfs(1, 0);

    for (int i = 1; i <= n; ++i){
        cout << ans[i] << " \n"[i == n];
    }
}
posted @ 2024-05-11 11:23  _Yxc  阅读(12)  评论(0)    收藏  举报