P3596 [POI2015]MOD

$ \color{#0066ff}{ 题目描述 }$

给定一棵无根树,边权都是1,请去掉一条边并加上一条新边,定义直径为最远的两个点的距离,请输出所有可能的新树的直径的最小值和最大值

\(\color{#0066ff}{输入格式}\)

第一行包含一个正整数n(3<=n<=500000),表示这棵树的点数。接下来n-1行,每行包含两个正整数u,v(1<=u,v<=n),表示u与v之间有一条边。

\(\color{#0066ff}{输出格式}\)

第一行输出五个正整数k,x1,y1,x2,y2,其中k表示新树直径的最小值,x1,y1表示这种情况下要去掉的边的两端点,x2,y2表示这种情况下要加上的边的两端点。第二行输出五个正整数k,x1,y1,x2,y2,其中k表示新树直径的最大值,x1,y1表示这种情况下要去掉的边的两端点,x2,y2表示这种情况下要加上的边的两端点。若有多组最优解,输出任意一组。

\(\color{#0066ff}{输入样例}\)

6
1 2
2 3
2 4
4 5
6 5

\(\color{#0066ff}{输出样例}\)

3 4 2 2 5
5 2 1 1 6

\(\color{#0066ff}{数据范围与提示}\)

none

\(\color{#0066ff}{题解}\)

最长的情况,肯定是两个最长链拼一起,最短,就是两个最长链的中点相连

跑树形DP处理出每个点子树内的最长链和子树外的最长链,更新答案

最后两遍bfs找到端点即可

#include<bits/stdc++.h>
#define LL long long
LL in() {
	char ch; LL x = 0, f = 1;
	while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
	for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
	return x * f;
}
const int inf = 0x7fffffff;
const int maxn = 1e6 + 10;
int f[maxn], g[maxn], up[maxn], dn[maxn][3], good[maxn][2], dep[maxn], F[maxn];
// 子树最长链,子树外最长链,以其为端点的最长链,次长链,次次长链(上下),子树f的最优,次优值,父亲
int n, minx, miny, maxx, maxy;
int min = inf, max;
struct node {
	int to;
	node *nxt;
	node(int to = 0, node *nxt = NULL): to(to), nxt(nxt) {}
}*head[maxn];
std::queue<int> q;
bool vis[maxn];
void add(int from, int to) {
	head[from] = new node(to, head[from]);
}
void dfs(int x, int fa) {
	F[x] = fa;
	dep[x] = dep[fa] + 1;
	for(node *i = head[x]; i; i = i->nxt) {
		if(i->to == fa) continue;
		dfs(i->to, x);
		f[x] = std::max(f[x], f[i->to]);
		int now = dn[i->to][0] + 1;
		if(now > dn[x][0]) {  //向下的最长,次长,次次长链
			dn[x][2] = dn[x][1];
			dn[x][1] = dn[x][0];
			dn[x][0] = now;
		}
		else if(now > dn[x][1]) {
			dn[x][2] = dn[x][1];
			dn[x][1] = now;
		}
		else if(now > dn[x][2]) dn[x][2] = now;
		now = f[i->to];
		if(now > good[x][0]) {  //更新子树最优f
			good[x][1] = good[x][0];
			good[x][0] = now;
		}
		else if(now > good[x][1]) good[x][1] = now;
	}
	f[x] = std::max(f[x], dn[x][0] + dn[x][1]);    //两条链更新f
}
void dfss(int x, int fa) {
	if(x != 1) {   //更新答案
		if(max < f[x] + g[x] + 1) {
			max = f[x] + g[x] + 1;
			maxx = fa, maxy = x;
		}
		int upd = std::max(std::max(f[x], g[x]), ((f[x] + 1) >> 1) + ((g[x] + 1) >> 1) + 1);
		if(min > upd) {
			min = upd;
			minx = fa, miny = x;
		}
	}
	for(node *i = head[x]; i; i = i->nxt) {
		if(i->to == fa) continue;
		up[i->to] = std::max(up[i->to], up[x] + 1);   //向上的最长链
		g[i->to] = std::max(g[i->to], g[x]);
		int now = dn[i->to][0] + 1;
		if(now == dn[x][0]) {    //讨论
			g[i->to] = std::max(g[i->to], std::max(dn[x][1] + dn[x][2], up[x] + dn[x][1]));
			up[i->to] = std::max(up[i->to], dn[x][1] + 1);
		}
		else if(now == dn[x][1]) {
			g[i->to] = std::max(g[i->to], std::max(dn[x][0] + dn[x][2], up[x] + dn[x][0]));
			up[i->to] = std::max(up[i->to], dn[x][0] + 1);
		}
		else {
			g[i->to] = std::max(g[i->to], std::max(dn[x][0] + dn[x][1], up[x] + dn[x][0]));
			up[i->to] = std::max(up[i->to], dn[x][0] + 1);
		}
		now = f[i->to];
		if(now == good[x][0]) g[i->to] = std::max(g[i->to], good[x][1]);
		else g[i->to] = std::max(g[i->to], good[x][0]);
		dfss(i->to, x);
	}
}
int getmid(int x, int y, int len) {   //找中点
	int now = len;
	if(dep[x] < dep[y]) x ^= y ^= x ^= y;
	while(now != (len + 1) >> 1) x = F[x], now--;
	return x;
}
int bfsmin(int s) {  
	for(int i = 1; i <= n; i++) vis[i] = false;
	vis[s] = true;
	q.push(s);
	while(!q.empty()) {
		int tp = q.front(); q.pop();
		s = tp;
		for(node *i = head[tp]; i; i = i->nxt) {
			if(vis[i->to]) continue;
			if(tp == minx && i->to == miny) continue;
			if(tp == miny && i->to == minx) continue;
			q.push(i->to);
			vis[i->to] = true;
		}
	}
	return s;
}
int bfsmax(int s) {
	for(int i = 1; i <= n; i++) vis[i] = false;
	vis[s] = true;
	q.push(s);
	while(!q.empty()) {
		int tp = q.front(); q.pop();
		s = tp;
		for(node *i = head[tp]; i; i = i->nxt) {
			if(vis[i->to]) continue;
			if(tp == maxx && i->to == maxy) continue;
			if(tp == maxy && i->to == maxx) continue;
			q.push(i->to);
			vis[i->to] = true;
		}
	}
	return s;
}		
void predoit() {
	dfs(1, 0), dfss(1, 0);
}
void getmin() {
	printf("%d %d %d ", min, minx, miny);
	int x = bfsmin(minx), y = bfsmin(x);
	int s = bfsmin(miny), t = bfsmin(s);
	printf("%d %d\n", getmid(x, y, g[miny]), getmid(s, t, f[miny]));   //连中点
}
void getmax() {
	printf("%d %d %d ", max, maxx, maxy);
	printf("%d %d\n", bfsmax(maxx), bfsmax(maxy));  //连端点
}
int main() {
	n = in();
	int x, y;
	for(int i = 1; i < n; i++) x = in(), y = in(), add(x, y), add(y, x);
	predoit();
	getmin();
	getmax();
	return 0;
}
posted @ 2019-03-07 15:09  olinr  阅读(328)  评论(0编辑  收藏  举报