树上路径最大子段和查询模板(同时维护最小子段和),基于线段树 + 树链剖分实现,支持单点修改。

也可以只使用其中的线段树部分,用于求序列的区间最大(最小)子段和。

/// Heavy-light decomposition + segment tree answering maximum / minimum
/// contiguous subsegment sums on tree paths, with point updates.
///
/// Usage (vertex ids must be >= 1: id 0 is used as the root's virtual parent
/// and as the "no heavy child" marker in `son`):
///   PathSubSegmentOnTree<long long> t(n);
///   t.w[u] = weight;                          // for every vertex u
///   t.AddEdge(u, v); t.AddEdge(v, u);         // for every tree edge
///   t.predfs(root, 0); t.dfs(root, root); t.build();
///   t.query(x, y).maxSum;  t.query(x, y).minSum;
///   t.modify(u, val);      // or t.update(1, t.dfn[u], val)
///
/// The segment-tree part can be used on its own for range max/min subarray
/// sums: fill a[1..tim], set tim, call build(), then segTreeQuery(1, l, r).
/// Path queries cost O(log^2 n), updates O(log n). predfs/dfs are recursive,
/// so very deep trees may need a larger stack.
template<class T>
struct PathSubSegmentOnTree {
    /// Summary of a contiguous sequence, read left to right.
    struct ST {
        int l, r;                    // segment-tree index range of this node
        T sum;                       // total sum
        T lMaxSum, rMaxSum, maxSum;  // best prefix / suffix / any subsegment (max)
        T lMinSum, rMinSum, minSum;  // best prefix / suffix / any subsegment (min)

        // Make this a single-element summary of x; l and r are left untouched.
        ST& operator = (T x) {
            sum = x;
            lMaxSum = rMaxSum = maxSum = x;
            lMinSum = rMinSum = minSum = x;
            return *this;
        }

        // Concatenate two adjacent sequences: lp immediately followed by rp.
        friend ST operator + (const ST &lp, const ST &rp) {
            ST res;
            res.l = lp.l;
            res.r = rp.r;
            res.sum = lp.sum + rp.sum;
            res.lMaxSum = std::max(lp.lMaxSum, lp.sum + rp.lMaxSum);
            res.rMaxSum = std::max(rp.rMaxSum, rp.sum + lp.rMaxSum);
            res.maxSum = std::max({lp.maxSum, rp.maxSum, lp.rMaxSum + rp.lMaxSum});

            res.lMinSum = std::min(lp.lMinSum, lp.sum + rp.lMinSum);
            res.rMinSum = std::min(rp.rMinSum, rp.sum + lp.rMinSum);
            res.minSum = std::min({lp.minSum, rp.minSum, lp.rMinSum + rp.lMinSum});
            return res;
        }
    };

    /// Adjacency-list edge; `nxt` links to the previous edge of the same tail.
    struct E {
        int to, nxt;
    };

    std::vector<int> son, siz, fa, dfn, top, dep, head;
    std::vector<T> a, w;   // a: values in dfn order (1..tim); w: values by vertex id
    std::vector<E> edge;   // all edges; head[u] indexes the newest edge out of u
    std::vector<ST> st;    // segment tree in 1-based heap layout
    int tim;               // number of vertices numbered by dfs()

    PathSubSegmentOnTree(int n)
        : son(n + 5), siz(n + 5), fa(n + 5), dfn(n + 5), top(n + 5),
          dep(n + 5), head(n + 5, -1), a(n + 5), w(n + 5),
          st(4 * n + 20), tim(0) {
        // Reserve rather than pre-size: AddEdge appends, so pre-sizing would
        // only store 2n+5 unused dummy edges in front of the real ones.
        edge.reserve(2 * n + 5);
    }

    /// Add the directed edge u -> v (call once per direction for a tree edge).
    void AddEdge(int u, int v) {
        edge.push_back({v, head[u]});
        head[u] = static_cast<int>(edge.size()) - 1;
    }

    /// Second HLD pass, called as dfs(root, root) after predfs: numbers the
    /// vertices heavy-child-first, so every heavy chain occupies a contiguous
    /// dfn range starting at its top, and copies w[u] into a[dfn[u]].
    void dfs(int u, int t) {
        dfn[u] = ++tim;
        a[tim] = w[u];
        top[u] = t;
        if (!son[u]) return;          // son == 0: no children
        dfs(son[u], t);               // heavy child continues the chain
        for (int i = head[u]; ~i; i = edge[i].nxt) {
            int v = edge[i].to;
            if (v == fa[u] || v == son[u]) continue;
            dfs(v, v);                // every light child starts a new chain
        }
    }

    // ---- Segment tree part ----
    int ls(int x) { return x << 1; }
    int rs(int x) { return x << 1 | 1; }

    void pushup(int rt) {
        st[rt] = st[ls(rt)] + st[rs(rt)];
    }

    /// Build node rt over a[l..r].
    void build(int rt, int l, int r) {
        st[rt].l = l;
        st[rt].r = r;
        if (l == r) {
            st[rt] = a[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(ls(rt), l, mid);
        build(rs(rt), mid + 1, r);
        pushup(rt);
    }

    /// Build over a[1..tim]; a no-op when nothing has been numbered
    /// (build(1, 1, 0) would otherwise recurse forever).
    void build() {
        if (tim > 0) build(1, 1, tim);
    }

    /// Point assignment at dfn position `pos`.
    /// a[] is kept in sync because query() reads the LCA's value from it.
    void update(int rt, int pos, T val) {
        int l = st[rt].l, r = st[rt].r;
        if (l == r) {
            st[rt] = val;
            a[l] = val;
            return;
        }
        int mid = (l + r) >> 1;
        if (pos <= mid) update(ls(rt), pos, val);
        else update(rs(rt), pos, val);
        pushup(rt);
    }

    /// Set the weight of vertex u to val (requires dfs() and build() first).
    void modify(int u, T val) {
        w[u] = val;
        update(1, dfn[u], val);
    }

    /// Summary of a[nl..nr] read left to right; requires nl <= nr inside
    /// node rt's range.
    ST segTreeQuery(int rt, int nl, int nr) {
        int l = st[rt].l, r = st[rt].r;
        if (nl <= l && r <= nr) return st[rt];
        int mid = (l + r) >> 1;
        if (nr <= mid) return segTreeQuery(ls(rt), nl, nr);
        if (mid < nl) return segTreeQuery(rs(rt), nl, nr);
        return segTreeQuery(ls(rt), nl, nr) + segTreeQuery(rs(rt), nl, nr);
    }

    // ---- Heavy-light decomposition part ----

    /// First HLD pass, called as predfs(root, 0): fills fa, dep, siz and the
    /// heavy child son[u] (the child with the largest subtree).
    void predfs(int u, int f) {
        int mxsiz = -1;
        siz[u] = 1;
        dep[u] = dep[f] + 1;
        fa[u] = f;
        for (int i = head[u]; ~i; i = edge[i].nxt) {
            int v = edge[i].to;
            if (v == f) continue;
            predfs(v, u);
            siz[u] += siz[v];
            if (siz[v] > mxsiz) {
                son[u] = v;
                mxsiz = siz[v];
            }
        }
    }

    /// Lowest common ancestor via chain jumping.
    int LCA(int x, int y) {
        while (top[x] != top[y]) {
            if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
            x = fa[top[x]];
        }
        return dep[x] < dep[y] ? x : y;
    }

    /// Summary of the vertical path between x and y, where one of them is an
    /// ancestor of the other; the result is read from the ancestor downward.
    /// Chain segments are collected bottom-up, each prepended to the result.
    ST treeLinkQuery(int x, int y) {
        ST res;
        bool hasRes = false;
        while (top[x] != top[y]) {
            if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
            ST tmp = segTreeQuery(1, dfn[top[x]], dfn[x]);
            res = hasRes ? tmp + res : tmp;
            hasRes = true;
            x = fa[top[x]];
        }
        if (dep[x] > dep[y]) std::swap(x, y);
        ST tmp = segTreeQuery(1, dfn[x], dfn[y]);
        return hasRes ? tmp + res : tmp;
    }

    /// Max / min contiguous subsegment sums on the directed path x -> y.
    /// Every value field of the result is filled in: lMax/lMin are prefixes
    /// starting at x, rMax/rMin are suffixes ending at y. l and r have no
    /// meaning for a path and are 0.
    ST query(int x, int y) {
        int lca = LCA(x, y);
        ST L = treeLinkQuery(x, lca);  // lca ... x, read from lca downward
        ST R = treeLinkQuery(lca, y);  // lca ... y, read from lca downward
        // The path is reverse(L) followed by R; lca appears in both, so its
        // value is subtracted once wherever the two halves are joined.
        T lcaVal = a[dfn[lca]];
        ST res{};
        res.sum = L.sum + R.sum - lcaVal;
        // Reversing L swaps its prefix and suffix roles.
        res.lMaxSum = std::max(L.rMaxSum, L.sum + R.lMaxSum - lcaVal);
        res.rMaxSum = std::max(R.rMaxSum, R.sum + L.lMaxSum - lcaVal);
        res.maxSum = std::max({L.maxSum, R.maxSum, L.lMaxSum + R.lMaxSum - lcaVal});
        res.lMinSum = std::min(L.rMinSum, L.sum + R.lMinSum - lcaVal);
        res.rMinSum = std::min(R.rMinSum, R.sum + L.lMinSum - lcaVal);
        res.minSum = std::min({L.minSum, R.minSum, L.lMinSum + R.lMinSum - lcaVal});
        return res;
    }
};

 

posted @ 2023-07-27 17:22  CECY  阅读(23)  评论(0)    收藏  举报