【LOJ#6669】Nauuo and Binary Tree

题目

题目链接:https://loj.ac/p/6669
这是一道交互题。

Nauuo 是一个喜欢二叉树的女孩子。

这天,她创造了一个有 \(n\) 个节点的二叉树。节点的编号从 \(1\)\(n\),其中 \(1\) 是二叉树的根节点。

不过,她不记得这棵二叉树具体长什么样子了,她只记录了二叉树上任意两个节点之间的距离。你可以通过向她询问有关距离的信息来还原这棵二叉树,两个节点之间的距离定义为它们之间最短路上的边数。

你可以向 Nauuo 询问不超过 \(30000\) 次有关距离的信息。你只需要告诉她 \(2\sim n\) 号节点的父亲的编号就可以了。

\(n\leq 3000\)

思路

先用 \(n-1\) 次询问得出每一个点的深度。然后从小到大枚举深度。
假设我们枚举到深度 \(i\),那么我们已经构建出深度为 \(1\sim i-1\) 的点的二叉树了。我们先将这棵树 dfs 一遍并重剖。
然后依次枚举深度为 \(i\) 的点 \(x\)。先令 \(y\)\(1\) 号节点,询问 \(y\) 所在重链深度最深的节点 \(z\)\(x\) 之间的距离。根据 \(dep_x+dep_z-2dep_{\mathrm{lca}(x,y)}=dis(x,z)\),可以得到 \(\mathrm{lca}(x,z)\) 的深度。
由于这个 \(\mathrm{lca}(x,z)\) 一定在 \(y\) 所在重链上,而一条重链上的点深度两两不同,所以我们可以确定点 \(\mathrm{lca}(x,z)\)
由于这是一棵二叉树,而且 \(x\) 一定和 \(z\) 分别位于 \(\mathrm{lca}(x,z)\) 的两棵子树内,所以我们可以确定 \(x\) 的一个祖先是 \(\mathrm{lca}(x,z)\)\(z\) 所在子树外的另一棵子树。那么我们就令 \(y\) 等于这棵子树的根节点,继续上述操作直到 \(dep_y=dep_x-1\) 即可。
根据重链剖分的性质,一个点向上的轻链数量是 \(O(\log n)\) 的,所以我们操作次数是 \(O(n+n\log n)\) 的。由于树剖常数很小,所以可以通过 \(30000\) 次的限制。
时间复杂度 \(O(n^2)\)

代码

#include <bits/stdc++.h>
using namespace std;

const int N=3010;
int n,fa[N],bot[N],size[N],dep[N],ch[N][2];
vector<int> pos[N];

void dfs(int x)
{
	size[x]=1; bot[x]=x;
	if (ch[x][0])
	{
		dfs(ch[x][0]);
		size[x]+=size[ch[x][0]];
	}
	if (ch[x][1])
	{
		dfs(ch[x][1]);
		size[x]+=size[ch[x][1]];
		if (size[ch[x][1]]>size[ch[x][0]])
			swap(ch[x][0],ch[x][1]);
	}
	if (ch[x][0]) bot[x]=bot[ch[x][0]];
}

void add(int from,int to)
{
	fa[to]=from;
	if (ch[from][0]) ch[from][1]=to;
		else ch[from][0]=to;
}

void solve(int x)
{
	int y=1,dis;
	while (dep[y]!=dep[x]-1)
	{
		printf("? %d %d\n",x,bot[y]);
		fflush(stdout);
		scanf("%d",&dis);
		dis=(dep[x]+dep[bot[y]]-dis)/2;
		while (dep[y]<dis) y=ch[y][0];
		if (dep[y]==dep[x]-1) break;
		y=ch[y][1];
	}
	add(y,x);
}

int main()
{
	scanf("%d",&n);
	for (int i=2;i<=n;i++)
	{
		printf("? 1 %d\n",i);
		fflush(stdout);
		scanf("%d",&dep[i]);
		pos[dep[i]].push_back(i);
	}
	for (int i=1;i<=n;i++)
	{
		dfs(1);
		for (int j=0;j<pos[i].size();j++)
			solve(pos[i][j]);
	}
	printf("!");
	for (int i=2;i<=n;i++)
		printf(" %d",fa[i]);
	printf("\n");
	fflush(stdout);
	return 0;
}
posted @ 2020-12-31 08:50  stoorz  阅读(413)  评论(0编辑  收藏  举报