【XSY1294】sub 树链剖分

题目描述

  给你一棵\(n\)个点的无根树,节点\(i\)有权值\(v_i\)。现在有\(m\)次操作,操作有如下两种:

   \(1~x~y\):把\(v_x\)改成\(y\)

   \(2\):选择一个连通块(也可以不选择),使得点权和最大。输出这个点权和。

  \(n,m\leq {10}^5,|v_i|,|y|\leq 1000\)

题解

  考虑暴力,\(dp_x=v_x+\sum_{y\text{是}x\text{的儿子}}\max(dp_y,0)\)

  还记得如果是一条链的问题要怎么做吗?没错,就是线段树。每个区间维护四个信息:

   sum表示整个区间的权值和

   ls表示以从左边界往右的连续段最大权值和

   rs表示以从右边界往左的连续段最大权值和

   s表示整个区间中连续段最大权值和

  现在这是一棵树,也可以用类似的方法处理。

  对于一条重链,每个点的权值就是这个点的\(v\)值加上轻儿子的\(dp\)值。那么这条链的答案就是线段树对应区间的连续段最大权值和。

  

  修改权值的时候,这条链对这条链顶部的轻链上方那个节点的贡献就是修改后从最上面往下的答案\(-\)修改前从最上面往下的答案。

  但是线段树上一个区间在原树上不一定是联通的。我们可以在两条不同的重链之间加一个点,权值为\(-\infty\)。这样所有的答案就不会跨过这个不存在的点了。

  时间复杂度:\(O(m\log^2 n)\)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<list>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
ll dp[200010];
int c[200010];
namespace seg
{
	struct pp
	{
		ll ls,rs,s,sum;
	};
	pp operator +(pp a,pp b)
	{
		pp c;
		c.sum=a.sum+b.sum;
		c.ls=max(a.ls,a.sum+b.ls);
		c.rs=max(b.rs,b.sum+a.rs);
		c.s=max(max(a.s,b.s),a.rs+b.ls);
		return c;
	}
	pp s[500010];
	int ls[500010];
	int rs[500010];
	int cnt;
	int build(int l,int r)
	{
		int p=++cnt;
		if(l==r)
		{
			s[p].sum=c[l];
			s[p].ls=s[p].rs=s[p].s=max(0ll,s[p].sum);
			return p;
		}
		int mid=(l+r)>>1;
		ls[p]=build(l,mid);
		rs[p]=build(mid+1,r);
		s[p]=s[ls[p]]+s[rs[p]];
		return p;
	}
	void change(int p,int x,ll v,int L,int R)
	{
		if(L==R)
		{
			s[p].sum+=v;
			s[p].ls=s[p].rs=s[p].s=max(0ll,s[p].sum);
			return;
		}
		int mid=(L+R)>>1;
		if(x<=mid)
			change(ls[p],x,v,L,mid);
		else
			change(rs[p],x,v,mid+1,R);
		s[p]=s[ls[p]]+s[rs[p]];
	}
	pp query(int p,int l,int r,int L,int R)
	{
		if(l<=L&&r>=R)
			return s[p];
		int mid=(L+R)>>1;
		if(r<=mid)
			return query(ls[p],l,r,L,mid);
		if(l>mid)
			return query(rs[p],l,r,mid+1,R);
		return query(ls[p],l,r,L,mid)+query(rs[p],l,r,mid+1,R);
	}
};
struct p
{
	int s,ms,f,d,t,w;
};
int inf=0x7fffffff;
p a[100010];
list<int> l[100010];
int v[100010];
int ti;
void dfs(int x,int f,int d)
{
	a[x].f=f;
	a[x].d=d;
	a[x].s=1;
	dp[x]=v[x];
	int sz=0;
	for(auto v:l[x])
		if(v!=f)
		{
			dfs(v,x,d+1);
			dp[x]+=max(dp[v],0ll);
			a[x].s+=a[v].s;
			if(a[v].s>sz)
			{
				sz=a[v].s;
				a[x].ms=v;
			}
		}
}
void dfs2(int x,int t)
{
	a[x].t=t;
	a[x].w=++ti;
	if(a[x].ms)
	{
		dp[x]-=max(dp[a[x].ms],0ll);
		dfs2(a[x].ms,t);
		for(auto v:l[x])
			if(v!=a[x].f&&v!=a[x].ms)
			{
				c[++ti]=-inf;
				dfs2(v,v);
			}
	}
	c[a[x].w]=dp[x];
}
void change(int x,int y)
{
	y-=v[x];
	v[x]+=y;
	for(;x;x=a[a[x].t].f)
	{
		int t=a[a[x].t].w;
		int v1=seg::query(1,t,ti,1,ti).ls;
		seg::change(1,a[x].w,y,1,ti);
		int v2=seg::query(1,t,ti,1,ti).ls;
		y=v2-v1;
	}
}
int main()
{
	freopen("a.in","r",stdin);
	freopen("a.out","w",stdout);
	seg::cnt=0;
	int n,m;
	scanf("%d%d",&n,&m);
	int i;
	for(i=1;i<=n;i++)
		scanf("%d",&v[i]);
	int x,y;
	for(i=1;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		l[x].push_back(y);
		l[y].push_back(x);
	}
	ti=0;
	dfs(1,0,1);
	dfs2(1,1);
	seg::build(1,ti);
	for(i=1;i<=m;i++)
	{
		scanf("%d",&x);
		if(x==1)
		{
			scanf("%d%d",&x,&y);
			change(x,y);
		}
		else
			printf("%lld\n",seg::s[1].s);
	}
	return 0;
}
posted @ 2018-03-05 20:54  ywwyww  阅读(186)  评论(2编辑  收藏  举报