[题解] [CF916E] Jamie and Tree

题解

题面

如果没有换根操作就直接上树剖加线段树即可

考虑换根操作如何转化

记当前的根节点为\(root\)

子树查询和子树修改类似, 在此只讨论子树查询, 假设当前要修改的是\(u\)子树

  • \(u = rt\), 直接修改整棵树即可

  • \(rt\)\(u\)的祖先或\(rt\)\(u\)在原先为\(1\)的两棵不同子树中, 修改\(u\)子树即可

  • \(u\)\(rt\)的祖先, 先修改整棵树, 然后找出\(rt\)的儿子中子树包含\(u\)的那一个儿子, 将以这个儿子为根的子树中撤销修改操作即可

现在问题就是求两点\(u, v\)在换根后的LCA了

这个点就是以\(1\)为根时\(u, root\)的LCA, \(v, root\)的LCA, \(u, v\)的LCA中深度最大的那个点

找出后按上面修改即可

查询同理

Code

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <vector>
#define itn int
#define reaD read
#define N 300005
using namespace std;
 
int n, m, w[N], cnt, head[N], f[N][21], sz[N], dep[N], son[N], top[N], pre[N], dfn[N], rt = 1; 
struct edge { int to, next; } e[N << 1];
struct Tree { long long sum, tag; } t[N << 2]; 
 
inline int read()
{
	int x = 0, w = 1; char c = getchar();
	while(c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
	while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
	return x * w;
}
 
inline void adde(int u, int v) { e[++cnt] = (edge) { v, head[u] }; head[u] = cnt; }
 
void dfs1(int u, int fa)
{
	f[u][0] = fa;
	dep[u] = dep[fa] + 1; sz[u] = 1; 
	for(int i = 1; i <= 20; i++)
		f[u][i] = f[f[u][i - 1]][i - 1];
	for(int i = head[u]; i; i = e[i].next)
	{
		int v = e[i].to; if(v == fa) continue;
		dfs1(v, u);
		sz[u] += sz[v]; if(sz[son[u]] < sz[v]) son[u] = v; 
	}
}
 
void dfs2(int x, int y)
{
	top[pre[dfn[x] = ++cnt] = x] = y;
	if(!son[x]) return; dfs2(son[x], y);
	for(int i = head[x]; i; i = e[i].next) if(e[i].to != son[x] && e[i].to != f[x][0]) dfs2(e[i].to, e[i].to); 
}
 
void build(int p, int l, int r)
{
	if(l == r) return (void) (t[p].sum = w[pre[l]], t[p].tag = 0);
	int mid = (l + r) >> 1;
	build(p << 1, l, mid); build(p << 1 | 1, mid + 1, r);
	t[p].sum = t[p << 1].sum + t[p << 1 | 1].sum; 
}
 
void pushdown(int p, int l, int r)
{
	if(t[p].tag)
	{
		int mid = (l + r) >> 1; 
		t[p << 1].tag += t[p].tag; t[p << 1].sum += 1ll * t[p].tag * (mid - l + 1);
		t[p << 1 | 1].tag += t[p].tag; t[p << 1 | 1].sum += 1ll * t[p].tag * (r - mid);
		t[p].tag = 0; 
	}
}
 
void modify(int p, int l, int r, int ql, int qr, int k)
{
	if(ql <= l && r <= qr) return (void) (t[p].sum += 1ll * k * (r - l + 1), t[p].tag += k); 
	pushdown(p, l, r);
	int mid = (l + r) >> 1;
	if(ql <= mid) modify(p << 1, l, mid, ql, qr, k);
	if(mid < qr) modify(p << 1 | 1, mid + 1, r, ql, qr, k);
	t[p].sum = t[p << 1].sum + t[p << 1 | 1].sum; 
}
 
int LCA(int x, int y)
{
	while(top[x] != top[y])
	{
		if(dep[top[x]] < dep[top[y]]) swap(x, y);
		x = f[top[x]][0]; 
	}
	return dep[x] < dep[y] ? x : y; 
}
 
int finds(int u, int v)
{
	if(dep[u] < dep[v]) swap(u, v);
	for(int i = 20; i >= 0; i--)
		if(dep[f[u][i]] > dep[v]) u = f[u][i];
	return u; 
}
 
long long query(int p, int l, int r, int ql, int qr)
{
	if(ql <= l && r <= qr) return t[p].sum; 
	pushdown(p, l, r); 
	int mid = (l + r) >> 1; long long ans = 0; 
	if(ql <= mid) ans += query(p << 1, l, mid, ql, qr); 
	if(mid < qr) ans += query(p << 1 | 1, mid + 1, r, ql, qr);
	t[p].sum = t[p << 1].sum + t[p << 1 | 1].sum; 
	return ans; 
}
 
int main()
{
/*	freopen("A.in", "r", stdin);
	freopen("A.out", "w", stdout); 
*/	n = read(); m = read();
	for(int i = 1; i <= n; i++) w[i] = read();
	for(int i = 1; i < n; i++)
	{
		int u = read(), v = read();
		adde(u, v); adde(v, u); 
	}
	cnt = 0; dfs1(1, 0); dfs2(1, 1);
	build(1, 1, n); 
	for(int i = 1; i <= m; i++)
	{
		int opt = read(); 
		if(opt == 1) rt = read();
		if(opt == 2)
		{
			int u = read(), v = read(), x = read(), lca = LCA(u, v); 
			if(LCA(lca, rt) == lca)
			{
				int lcau = LCA(rt, u), lcav = LCA(rt, v); 
				lcau = dep[lcau] < dep[lcav] ? lcav : lcau;
				modify(1, 1, n, 1, n, x); 
				if(rt != lcau)
				{
					int s = finds(rt, lcau);
					modify(1, 1, n, dfn[s], dfn[s] + sz[s] - 1, -x); 
				}
			}
			else modify(1, 1, n, dfn[lca], dfn[lca] + sz[lca] - 1, x); 
		}
		if(opt == 3)
		{
			int u = read(); 
			if(u == rt) printf("%I64d\n", query(1, 1, n, 1, n)); 
			else
			{
				int lca = LCA(u, rt); 
				if(lca == u)
				{
					int s = finds(rt, u); 
					printf("%I64d\n", query(1, 1, n, 1, n) - query(1, 1, n, dfn[s], dfn[s] + sz[s] - 1)); 
				}
				else printf("%I64d\n", query(1, 1, n, dfn[u], dfn[u] + sz[u] - 1)); 
			}
		}
	}
	return 0;
} 
posted @ 2019-07-14 16:29  ztlztl  阅读(339)  评论(0编辑  收藏  举报