【bzoj1095】[ZJOI2007]Hide 捉迷藏 动态点分治+堆

题目描述

捉迷藏 Jiajia和Wind是一对恩爱的夫妻,并且他们有很多孩子。某天,Jiajia、Wind和孩子们决定在家里玩捉迷藏游戏。他们的家很大且构造很奇特,由N个屋子和N-1条双向走廊组成,这N-1条走廊的分布使得任意两个屋子都互相可达。游戏是这样进行的,孩子们负责躲藏,Jiajia负责找,而Wind负责操纵这N个屋子的灯。在起初的时候,所有的灯都没有被打开。每一次,孩子们只会躲藏在没有开灯的房间中,但是为了增加刺激性,孩子们会要求打开某个房间的电灯或者关闭某个房间的电灯。为了评估某一次游戏的复杂性,Jiajia希望知道可能的最远的两个孩子的距离(即最远的两个关灯房间的距离)。 我们将以如下形式定义每一种操作: C(hange) i 改变第i个房间的照明状态,若原来打开,则关闭;若原来关闭,则打开。 G(ame) 开始一次游戏,查询最远的两个关灯房间的距离。

输入

第一行包含一个整数N,表示房间的个数,房间将被编号为1,2,3…N的整数。接下来N-1行每行两个整数a, b,表示房间a与房间b之间有一条走廊相连。接下来一行包含一个整数Q,表示操作次数。接着Q行,每行一个操作,如上文所示。

输出

对于每一个操作Game,输出一个非负整数到hide.out,表示最远的两个关灯房间的距离。若只有一个房间是关着灯的,输出0;若所有房间的灯都开着,输出-1。

样例输入

8
1 2
2 3
3 4
3 5
3 6
6 7
6 8
7
G
C 1
G
C 2
G
C 1
G

样例输出

4
3
3
4


题解

动态点分治 +堆

动态点分治:将点分治的上一层重心与下一层连边,可以得到一棵新树(点分树)。由于每次都是找重心,所以树高不超过$\log$,就可以使用各种数据结构维护各种子树信息。

考虑如果本题是静态的,只有一次查询该怎么做:求出以每个点为根的最长路径,即求 $|$所有节点的 $|$子节点的 $|$子树中的节点到父亲节点的最大值$|$ 的最大值和次大值的和$|$ 的最大值$|$。($|$为断句方法= =)

形象一点,求每个点的子树中的所有节点到父亲节点的距离的最大值$p1$;每个点求出它所有子节点的$p1$以及当前节点状态(存在则为0)中的最大的和次大的,加起来得到$p2$;所有节点的$p2$的最大值就是$p3$。

考虑带修改,多次查询:首先由于有修改,所以树高必须要有保证,所以选择动态树分治的点分树结构。

那么需要是用数据结构,支持查询最大值和次大值,使用3种堆:

$s1[]$:维护一个子树中所有节点到当前点的父亲节点的距离;

$s2[]$:维护一个点的所有子分治节点(点分树中子节点)的$s1$中的最大值,如果当前节点可用,则需要再增加一个$0$;

$s3$:维护所有节点的$s2$的最大值与次大值(如果存在)之和。

每次$s3$的最大值就是答案。

于是就可以自底向上修改路径上的$s1$和$s2$,并修改$s3$。具体实现较为复杂:需要消除下一级对上一级的影响,所以要先删除上一级,再插入上一级;需要实现可以删除的堆,于是需要维护两个堆,删除时将要删的数加入到辅助堆中,每次取堆顶时如果两堆堆顶相同则都弹出。

并且需要维护欧拉遍历序并使用RMQLCA支持$O(1)$查询LCA以保证时间复杂度。

总时间复杂度为$O(n\log^2n)$,空间复杂度为$O(n\log n)$。

#include <queue>
#include <cstdio>
#define N 100010
using namespace std;
struct heap
{
	priority_queue<int> A , B;
	void push(int x) {A.push(x);}
	void del(int x) {B.push(x);}
	int top()
	{
		while(!B.empty() && A.top() == B.top()) A.pop() , B.pop();
		return A.top();
	}
	int sum()
	{
		int a = top(); A.pop();
		int b = top(); push(a);
		return a + b;
	}
	int size() {return A.size() - B.size();}
}s1[N] , s2[N] , s3;
int head[N] , to[N << 1] , next[N << 1] , cnt , vis[N] , deep[N] , pos[N] , md[20][N << 1];
int si[N] , mx[N] , sum , root , log[N << 1] , tot , fa[N] , val[N];
char str[5];
void insert(heap &s) {if(s.size() >= 2) s3.push(s.sum());}
void erase(heap &s) {if(s.size() >= 2) s3.del(s.sum());}
void add(int x , int y)
{
	to[++cnt] = y , next[cnt] = head[x] , head[x] = cnt;
}
void dfs(int x , int fa)
{
	int i;
	md[0][++tot] = deep[x] , pos[x] = tot;
	for(i = head[x] ; i ; i = next[i])
		if(to[i] != fa)
			deep[to[i]] = deep[x] + 1 , dfs(to[i] , x) , md[0][++tot] = deep[x];
}
int lca(int x , int y)
{
	x = pos[x] , y = pos[y];
	if(x > y) swap(x , y);
	int k = log[y - x + 1];
	return min(md[k][x] , md[k][y - (1 << k) + 1]);
}
void getroot(int x , int fa)
{
	int i;
	si[x] = 1 , mx[x] = 0;
	for(i = head[x] ; i ; i = next[i])
		if(!vis[to[i]] && to[i] != fa)
			getroot(to[i] , x) , si[x] += si[to[i]] , mx[x] = max(mx[x] , si[to[i]]);
	mx[x] = max(mx[x] , sum - si[x]);
	if(mx[x] < mx[root]) root = x;
}
void solve(int x)
{
	int i;
	vis[x] = 1;
	for(i = head[x] ; i ; i = next[i])
		if(!vis[to[i]])
			sum = si[to[i]] , root = 0 , getroot(to[i] , 0) , fa[root] = x , solve(root);
}
void join(int x)
{
	erase(s2[x]) , s2[x].push(0) , insert(s2[x]);
	int t;
	for(t = x ; fa[t] ; t = fa[t])
	{
		erase(s2[fa[t]]);
		if(s1[t].size()) s2[fa[t]].del(s1[t].top());
		s1[t].push(deep[fa[t]] + deep[x] - 2 * lca(fa[t] , x)) , s2[fa[t]].push(s1[t].top());
		insert(s2[fa[t]]);
	}
}
void remove(int x)
{
	erase(s2[x]) , s2[x].del(0) , insert(s2[x]);
	int t;
	for(t = x ; fa[t] ; t = fa[t])
	{
		erase(s2[fa[t]]);
		s2[fa[t]].del(s1[t].top()) , s1[t].del(deep[fa[t]] + deep[x] - 2 * lca(fa[t] , x));
		if(s1[t].size()) s2[fa[t]].push(s1[t].top());
		insert(s2[fa[t]]);
	}
}
int main()
{
	int n , m , i , j , x , y , num;
	scanf("%d" , &n) , num = n;
	for(i = 1 ; i < n ; i ++ ) scanf("%d%d" , &x , &y) , add(x , y) , add(y , x);
	dfs(1 , 0);
	for(i = 2 ; i <= tot ; i ++ ) log[i] = log[i >> 1] + 1;
	for(i = 1 ; (1 << i) <= tot ; i ++ )
		for(j = 1 ; j <= tot - (1 << i) + 1 ; j ++ )
			md[i][j] = min(md[i - 1][j] , md[i - 1][j + (1 << (i - 1))]);
	mx[0] = 1 << 30 , sum = n , getroot(1 , 0) , solve(root);
	for(i = 1 ; i <= n ; i ++ ) val[i] = 1 , join(i);
	scanf("%d" , &m);
	while(m -- )
	{
		scanf("%s" , str);
		if(str[0] == 'G')
		{
			if(num >= 2) printf("%d\n" , s3.top());
			else printf("%d\n" , num - 1);
		}
		else
		{
			scanf("%d" , &x);
			if(val[x]) num -- , val[x] = 0 , remove(x);
			else num ++ , val[x] = 1 , join(x);
		}
	}
	return 0;
}

 

posted @ 2017-08-31 10:43  GXZlegend  阅读(565)  评论(1编辑  收藏  举报