BZOJ 3531: [Sdoi2014]旅行 权值线段树 + 树链剖分

Description

 S国有N个城市,编号从1到N。城市间用N-1条双向道路连接,满足
从一个城市出发可以到达其它所有城市。每个城市信仰不同的宗教,如飞天面条神教、隐形独角兽教、绝地教都是常见的信仰。为了方便,我们用不同的正整数代表各种宗教,  S国的居民常常旅行。旅行时他们总会走最短路,并且为了避免麻烦,只在信仰和他们相同的城市留宿。当然旅程的终点也是信仰与他相同的城市。S国政府为每个城市标定了不同的旅行评级,旅行者们常会记下途中(包括起点和终点)留宿过的城市的评级总和或最大值。
    在S国的历史上常会发生以下几种事件:
”CC x c”:城市x的居民全体改信了c教;
”CW x w”:城市x的评级调整为w;
”QS x y”:一位旅行者从城市x出发,到城市y,并记下了途中留宿过的城市的评级总和;
”QM x y”:一位旅行者从城市x出发,到城市y,并记下了途中留宿过
的城市的评级最大值。
    由于年代久远,旅行者记下的数字已经遗失了,但记录开始之前每座城市的信仰与评级,还有事件记录本身是完好的。请根据这些信息,还原旅行者记下的数字。    为了方便,我们认为事件之间的间隔足够长,以致在任意一次旅行中,所有城市的评级和信仰保持不变。

Input

    输入的第一行包含整数N,Q依次表示城市数和事件数。
    接下来N行,第i+l行两个整数Wi,Ci依次表示记录开始之前,城市i的
评级和信仰。
    接下来N-1行每行两个整数x,y表示一条双向道路。
    接下来Q行,每行一个操作,格式如上所述。

Output

    对每个QS和QM事件,输出一行,表示旅行者记下的数字。

 题解:对于每一个宗教分别开一个线段树. 

下标为树剖序,权值为树上点权. 

由于宗教数目是 $O(n)$ 的,动态开点即可. 

#include <bits/stdc++.h>
#define setIO(s) freopen(s".in","r",stdin)
#define ll long long 
#define inf 100000000000 
#define maxn 500000 
#define N 200003 
using namespace std; 
namespace Seg
{
    #define lson t[x].l 
    #define rson t[x].r 
    int n, tot; 
    struct Node
    {
        int l, r;
        ll sumv, maxv; 
    }t[maxn << 2];       
    void pushup(int x)
    {
    	t[x].sumv = t[lson].sumv + t[rson].sumv; 
    	t[x].maxv = max(t[lson].maxv, t[rson].maxv); 
    }
    void ins(int &x, int l, int r, int p, ll v)
    {
    	if(!x) x = ++ tot; 
    	if(l == r) 
    	{
    		t[x].sumv = t[x].maxv = v; 
    		return ; 
    	}
    	int mid = (l + r) >> 1; 
    	if(p <= mid) ins(lson, l, mid, p, v); 
    	else ins(rson, mid + 1, r, p, v); 
    	pushup(x); 
    }
    void del(int x, int l, int r, int p)
    {
    	if(l == r) 
    	{
    		t[x].sumv = t[x].maxv = 0; 
    		return ; 
    	} 
    	int mid = (l + r) >> 1; 
    	if(p <= mid) del(lson, l, mid, p); 
    	else del(rson, mid + 1, r, p); 
    	pushup(x); 
    }
    ll query_sum(int l, int r, int x, int L, int R)
    {
    	if(!x) return 0; 
    	if(l >= L && r <= R) return t[x].sumv; 
    	ll tmp = 0;
    	int mid = (l + r) >> 1;
    	if(L <= mid) tmp += query_sum(l, mid, lson, L, R); 
    	if(R > mid) tmp += query_sum(mid + 1, r, rson, L, R); 
    	return tmp; 
    }
    ll query_max(int l, int r, int x, int L, int R)
    {
    	if(!x) return -inf; 
    	if(l >= L && r <= R) return t[x].maxv; 
    	ll tmp = -inf; 
    	int mid = (l + r) >> 1;
    	if(L <= mid) tmp = max(tmp, query_max(l, mid, lson, L, R)); 
    	if(R > mid) tmp = max(tmp, query_max(mid + 1, r, rson, L, R)); 
    	return tmp; 
    }
    #undef lson
    #undef rson
}; 
char str[10]; 
int n, Q, edges, tim; 
int hd[maxn], to[maxn << 1], nex[maxn << 1], W[maxn], C[maxn], fa[maxn], dep[maxn]; 
int ln[maxn], dfn[maxn], top[maxn], bot[maxn], siz[maxn], hson[maxn], rt[maxn]; 
void add(int u, int v)
{
    nex[++edges] = hd[u], hd[u] = edges, to[edges] = v; 
} 
void dfs1(int u, int ff)
{
    siz[u] = 1, fa[u] = ff, dep[u] = dep[ff] + 1; 
    for(int i = hd[u]; i ; i = nex[i])
    {
        int v = to[i]; 
        if(v == ff) continue; 
        dfs1(v, u); 
        siz[u] += siz[v]; 
        if(siz[hson[u]] < siz[v]) hson[u] = v; 
    }
}
void dfs2(int u, int tp)
{
    top[u] = tp, ln[++tim] = u, dfn[u] = tim; 
    Seg :: ins(rt[C[u]], 1, N, tim, 1ll*W[u]); 
    if(hson[u]) 
        dfs2(hson[u], tp), bot[u] = bot[hson[u]]; 
    else 
        bot[u] = u; 
    for(int i = hd[u]; i ; i = nex[i])
    {
        int v = to[i];
        if(v == fa[u] || v == hson[u]) continue; 
        dfs2(v, v); 
    }
}
ll _query_sum(int x, int y)
{
	int ty = C[y];      
	ll tmp = 0; 
	// y is the deeper one 
	while(top[x] ^ top[y])
	{
		if(dep[top[x]] > dep[top[y]]) swap(x, y); 
		tmp += Seg :: query_sum(1, N, rt[ty], dfn[top[y]], dfn[y]);       
		y = fa[top[y]]; 
	} 
	if(dep[x] > dep[y]) swap(x, y); 
	tmp += Seg :: query_sum(1, N, rt[ty], dfn[x], dfn[y]); 
	return tmp; 
}
ll _query_max(int x, int y)
{
	int ty = C[y]; 
	ll tmp = 0;
	while(top[x] ^ top[y]) 
	{
		if(dep[top[x]] > dep[top[y]]) swap(x, y); 
		tmp = max(tmp, Seg :: query_max(1, N, rt[ty], dfn[top[y]], dfn[y])); 
		y = fa[top[y]]; 
	}
	if(dep[x] > dep[y]) swap(x, y); 
	tmp = max(tmp, Seg :: query_max(1, N, rt[ty], dfn[x], dfn[y])); 
	return tmp; 
}
int main()
{
    // setIO("input");  
    scanf("%d%d",&n,&Q); 
    for(int i = 1;i <= n; ++i)  scanf("%d%d",&W[i],&C[i]);
    for(int i = 1, u, v; i < n; ++i)
    {
        scanf("%d%d",&u,&v), add(u, v), add(v, u); 
    } 
    Seg :: t[0].maxv = -inf; 
    dfs1(1, 0), dfs2(1, 1); 
    while(Q--)
    {
    	scanf("%s",str); 
    	int x, w, c, y; 
    	if(str[1] == 'C')
    	{
    		scanf("%d%d",&x,&c); 
    		Seg :: del(rt[C[x]], 1, N, dfn[x]); 
    		C[x] = c; 
    		Seg :: ins(rt[C[x]], 1, N, dfn[x], W[x]);        
    	}
    	if(str[1] == 'W')
    	{
    		scanf("%d%d",&x,&w); 
    		W[x] = w; 
    		Seg :: ins(rt[C[x]], 1, N, dfn[x], 1ll*W[x]);     
    	}
    	if(str[1] == 'S')
    	{
    		scanf("%d%d",&x,&y), printf("%lld\n",_query_sum(x, y)); 
    	}
    	if(str[1] == 'M')
    	{
    		scanf("%d%d",&x,&y), printf("%lld\n",_query_max(x, y)); 
    	}
    }
    return 0; 
}

  

posted @ 2019-06-06 00:17  EM-LGH  阅读(139)  评论(0编辑  收藏  举报