[2018HN省队集训D1T1] Tree

[2018HN省队集训D1T1] Tree

题意

给定一棵带点权树, 要求支持下面三种操作:

  • 1 rootroot 设为根.
  • 2 u v d 将以 \(\operatorname{LCA} (u,v)\) 为根的子树中的点权值加上 \(d\).
  • 3 u 查询以 \(u\) 为根的子树中的点的权值之和.

初始时根为 \(1\).

\(n,q\le3\times 10^5\)

时限 \(1\texttt{s}\).

题解

垃圾卡常题毁我青春

写这个题解主要是存板子的...毕竟LCT上比较科学优雅地实现LCA需要改板子...但是我的LCT长得比较滑稽没几个人写得和我一样

考场上主要思路是当成把换根子树修改和换根LCA分开算, 子树修改可以对DFS序建线段树解决. 具体做法是分类讨论新根 \(p\) /原来的根 \(r\) /要修改的子树的根 \(u\) 三个点的位置关系. 若 \(u\) 不在 \(p\)\(r\) 的路径上, 那么直接修改 \(u\) 在以 \(r\) 为根的子树即可. 否则设 \(u\rightarrow v \leadsto r\), 那么除了 \(v\) 的子树之外的所有点都要修改, 分两段解决或者先整体加再子树减也可以.

换根LCA是LCT的标准操作. 比较科学地求LCA需要在 Access 的时候返回最后一次连接的虚边的父亲侧结点, 就可以两次 Access 求出LCA了.

考场上打完没过样例发现Access的时候把虚子树连到Splay左儿子去了囧...

然后这题数据丧病地出到了 3e5 所以需要常数优化一下比如加个快读3e5的读入量还不加快读显然是自己作死吧

参考代码

#include <bits/stdc++.h>

const int MAXE=1e6+10;
const int MAXV=3e5+10;
typedef long long intEx;

struct Edge{
	int from;
	int to;
	Edge* next;
};
Edge E[MAXE];
Edge* head[MAXV];
Edge* top=E;

struct LCT{
#define lch chd[0]
#define rch chd[1]
#define kch chd[k]
#define xch chd[k^1]
	struct Node{
		int id;
		bool rev;
		Node* prt;
		Node* pprt;
		Node* chd[2];
		Node(int id):id(id),rev(false),prt(NULL),pprt(NULL),chd{NULL,NULL}{}
		inline void Flip(){
			if(this!=NULL){
				this->rev=!this->rev;
				std::swap(this->lch,this->rch);
			}
		}
		inline void PushDown(){
			if(this!=NULL&&this->rev){
				this->lch->Flip();
				this->rch->Flip();
				this->rev=false;
			}
		}
	};
	std::vector<Node*> N;
	LCT(int n):N(n+1){
		for(int i=1;i<=n;i++)
			N[i]=new Node(i);
	}
	inline void Rotate(Node* root,int k){
		Node* tmp=root->xch;
		root->PushDown();
		tmp->PushDown();
		tmp->prt=root->prt;
		if(root->prt==NULL){
			tmp->pprt=root->pprt;
			root->pprt=NULL;
		}
		else if(root->prt->lch==root)
			root->prt->lch=tmp;
		else
			root->prt->rch=tmp;
		root->xch=tmp->kch;
		if(root->xch!=NULL)
			root->xch->prt=root;
		tmp->kch=root;
		root->prt=tmp;
	}
	inline void Splay(Node* root){
		while(root->prt!=NULL){
			int k=root->prt->lch==root;
			if(root->prt->prt==NULL)
				Rotate(root->prt,k);
			else{
				int d=root->prt->prt->lch==root->prt;
				Rotate(k==d?root->prt->prt:root->prt,k);
				Rotate(root->prt,d);
			}
		}
	}
	inline void Expose(Node* root){
		Splay(root);
		root->PushDown();
		if(root->rch!=NULL){
			root->rch->prt=NULL;
			root->rch->pprt=root;
			root->rch=NULL;
		}
	}
	inline	Node* Access(Node* root){
		Expose(root);
		Node* ret=root;
		while(root->pprt!=NULL){
			ret=root->pprt;
			Expose(root->pprt);
			root->pprt->rch=root;
			root->prt=root->pprt;
			root->pprt=NULL;
			Splay(root);
		}
		return ret;
	}
	inline	void Evert(Node* root){
		Access(root);
		Splay(root);
		root->Flip();
	}
	inline	void Evert(int root){
		Evert(N[root]);
	}
	inline	void Link(int prt,int son){
		Evert(N[son]);
		N[son]->pprt=N[prt];
	}
	inline	int LCA(int x,int y){
		Access(N[x]);
		return Access(N[y])->id;
	}
#undef lch
#undef rch
#undef kch
#undef xch
};

struct Node{
	int l;
	int r;
	intEx add;
	intEx sum;
	Node* lch;
	Node* rch;
	Node(int,int);
	void Maintain();
	void PushDown();
	intEx Query(int,int);
	void Add(const intEx&);
	void Add(int,int,const intEx&);
};

int n;
int q;
int clk;
int val[MAXV];
int pos[MAXV];
int dfn[MAXV];
int deep[MAXV];
int size[MAXV];
int prt[20][MAXV];

void ReadInt(int&);
void Insert(int,int);
int Ancestor(int,int);
void DFS(int,int,int);

int main(){
	ReadInt(n);
	ReadInt(q);
	for(int i=1;i<=n;i++)
		ReadInt(val[i]);
	LCT* T=new LCT(n);
	for(int i=1;i<n;i++){
		int a,b;
		ReadInt(a);
		ReadInt(b);
		Insert(a,b);
		Insert(b,a);
		T->Link(a,b);
	}
	DFS(1,0,0);
	for(int i=1;(1<<i)<=n;i++)
		for(int j=1;j<=n;j++)
			prt[i][j]=prt[i-1][prt[i-1][j]];
	Node* N=new Node(1,n);
	T->Evert(1);
	int root=1;
	for(int i=0;i<q;i++){
		int t;
		ReadInt(t);
		if(t==1){
			ReadInt(root);
			T->Evert(root);
		}
		else if(t==2){
			int a,b,d;
			ReadInt(a);
			ReadInt(b);
			ReadInt(d);
			int lca=T->LCA(a,b);
			if(lca==root)
				N->Add(1,n,d);
			else if(deep[lca]>=deep[root])
				N->Add(dfn[lca],dfn[lca]+size[lca]-1,d);
			else if(Ancestor(root,deep[root]-deep[lca])==lca){
				int x=Ancestor(root,deep[root]-deep[lca]-1);
				N->Add(1,n,d);
				N->Add(dfn[x],dfn[x]+size[x]-1,-d);
			}
			else
				N->Add(dfn[lca],dfn[lca]+size[lca]-1,d);
		}
		else if(t==3){
			int r;
			ReadInt(r);
			intEx ans=0;
			if(r==root)
				ans=N->Query(1,n);
			else if(deep[r]>=deep[root])
				ans=N->Query(dfn[r],dfn[r]+size[r]-1);
			else if(Ancestor(root,deep[root]-deep[r])==r){
				int x=Ancestor(root,deep[root]-deep[r]-1);
				ans+=N->Query(1,n);
				ans-=N->Query(dfn[x],dfn[x]+size[x]-1);
			}
			else
				ans=N->Query(dfn[r],dfn[r]+size[r]-1);
			printf("%lld\n",ans);
		}
	}
	return 0;
}

inline int Ancestor(int cur,int k){
	for(int i=0;(1<<i)<=k;i++)
		if((1<<i)&k)
			cur=prt[i][cur];
	return cur;
}

void DFS(int root,int prt,int deep){
	::size[root]=1;
	::dfn[root]=++clk;
	::deep[root]=deep;
	::prt[0][root]=prt;
	::pos[dfn[root]]=root;
	for(Edge* i=head[root];i!=NULL;i=i->next){
		if(i->to!=prt){
			DFS(i->to,root,deep+1);
			size[root]+=size[i->to];
		}
	}
}

inline void Insert(int from,int to){
	top->from=from;
	top->to=to;
	top->next=head[from];
	head[from]=top++;
}

Node::Node(int l,int r):l(l),r(r),add(0),lch(NULL),rch(NULL){
	if(l==r)
		sum=val[pos[l]];
	else{
		int mid=(l+r)>>1;
		this->lch=new Node(l,mid);
		this->rch=new Node(mid+1,r);
		this->sum=this->lch->sum+this->rch->sum;
	}
}

void Node::Add(int l,int r,const intEx& d){
	if(l<=this->l&&this->r<=r)
		this->Add(d);
	else{
		this->PushDown();
		if(l<=this->lch->r)
			this->lch->Add(l,r,d);
		if(this->rch->l<=r)
			this->rch->Add(l,r,d);
		this->Maintain();
	}
}

intEx Node::Query(int l,int r){
	if(l<=this->l&&this->r<=r)
		return this->sum;
	else{
		this->PushDown();
		if(r<=this->lch->r)
			return this->lch->Query(l,r);
		if(this->rch->l<=l)
			return this->rch->Query(l,r);
		return this->lch->Query(l,r)+this->rch->Query(l,r);
	}
}

void Node::Maintain(){
	this->sum=this->lch->sum+this->rch->sum;
}

void Node::PushDown(){
	if(this->add){
		this->lch->Add(this->add);
		this->rch->Add(this->add);
		this->add=0;
	}
}

inline void Node::Add(const intEx& d){
	this->add+=d;
	this->sum+=d*(r-l+1);
}

inline void ReadInt(int& target){
	target=0;
	int sgn=1;
	register char ch=getchar();
	while(!isdigit(ch)){
		if(ch=='-')
			sgn=-sgn;
		ch=getchar();
	}
	while(isdigit(ch)){
		target=target*10+ch-'0';
		ch=getchar();
	}
	target*=sgn;
}

posted @ 2019-02-28 15:50  rvalue  阅读(335)  评论(0编辑  收藏  举报