HDU4871 Shortest-path tree(树分治)

好久没做过树分治的题了,对上一次做是在南京赛里跪了一道很裸的树分治题后学的一道,多校的时候没有看这道题,哪怕看了感觉也看不出来是树分治,看出题人给了解题报告里写了树分治就做一下好了。

题意其实就是给你一个图,然后让你转换成一棵树,这棵树满足的是根节点1到其余各点的距离都是图里的最短距离,而且为了保证这棵树的唯一性,路径也必须是最小的。转化成树的方法其实就是跑一次spfa。spfa的时候记下所有到这个的前驱的边,然后这些边集反向的边补上就是构成所有最短路的边。然后在这些边上跑一次dfs,跑前将边按照到达点的序号由小到大排序,注意dfs搜的下一个点的距离必须是最短的才搜,不然的话搜出来的图就是不对的,比划一下题目给的样例就知道了。

至此图的部分转化完了,剩下的就是求一个图里包含了k个点的路径的最长距离,以及有多少条,相似的问题还有有多少条路径的乘积=k,有多少条路径的和>k,有多少条路径的乘积是完全立方数。。。做法就是典型的树分治。

树分治在《挑战程序设计竞赛》这本书上有一个很好的框架可以直接抄,我就直接拿来用了。具体的做法是找出重心,对重心外的部分递归求解,合并的时候枚举到重心的所有路径,枚举的时候可以用一个全局的map ds记录当前到达这个点的所有情况,然后用一个tds去枚举新的部分的路径,然后通过ds和tds更新答案,更新完后将tds的内容加进去ds。下面贴一记代码好了

#pragma warning(disable:4996)
#include <iostream>
#include <cstring>
#include <string>
#include <vector>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <queue>
#include <map>
using namespace std;

#define ll long long
#define maxn 31000
#define maxm 61000
#define MP make_pair

struct Edge{
	int v, w;
	Edge(int vi, int wi) :v(vi), w(wi){}
	Edge(){}
	bool operator < (const Edge & b) const{
		return v < b.v;
	}
};

vector<Edge> G[maxn];
vector<Edge> E[maxn];
vector<Edge> EE[maxn];
vector<Edge> T[maxn];

int n, m, k;

int d[maxn];
int dx[maxn];
bool in[maxn];

void dfs(int u,int dis)
{
	in[u] = true; dx[u] = dis;
	if (dx[u] != d[u]) puts("fuck");
	for (int i = 0; i < EE[u].size(); i++){
		int v = EE[u][i].v, w = EE[u][i].w;
		if (!in[v]&&w+dis==d[v]) {
			T[u].push_back(Edge(v, w));
			T[v].push_back(Edge(u, w));
			dfs(v, w + dis);
		}
	}
}

void spfa()
{
	queue<int> que;
	memset(in, 0, sizeof(in));
	memset(d, 0x3f, sizeof(d));
	d[1] = 0; in[1] = true; que.push(1);
	while (!que.empty()){
		int u = que.front(); que.pop(); in[u] = false;
		for (int i = 0; i < G[u].size(); i++){
			int v = G[u][i].v, w = G[u][i].w;
			if (d[u] + w < d[v]){
				d[v] = d[u] + w;
				if (!in[v]) {
					in[v] = true; que.push(v);
				}
				E[v].clear(); E[v].push_back(Edge(u, w));
			}
			else if (d[u] + w == d[v]){
				E[v].push_back(Edge(u, w));
			}
		}
	}
	for (int i = 1; i <= n; i++){
		for (int j = 0; j < E[i].size(); j++){
			EE[E[i][j].v].push_back(Edge(i, E[i][j].w));
			EE[i].push_back(E[i][j]);
		}
	}
	for (int i = 1; i <= n; i++) sort(EE[i].begin(), EE[i].end());
	memset(in, 0, sizeof(in));
	memset(dx, 0x3f, sizeof(dx));
	dfs(1,0);
}

bool centroid[maxn];
int ssize[maxn];

int compute_subtree_size(int v, int p){
	int c = 1;
	for (int i = 0; i < T[v].size(); i++){
		int  w = T[v][i].v;
		if (w == p || centroid[w]) continue;
		c += compute_subtree_size(w, v);
	}
	ssize[v] = c;
	return c;
}

pair<int, int> search_centroid(int v, int p, int t){
	pair<int, int> res = MP(INT_MAX, -1);
	int s = 1, m = 0;
	for (int i = 0; i < T[v].size(); i++){
		int w = T[v][i].v;
		if (w == p || centroid[w]) continue;

		res = min(res, search_centroid(w, v, t));

		m = max(m, ssize[w]);
		s += ssize[w];
	}
	m = max(m, t - s);
	res = min(res, MP(m, v));
	return res;
}

map<int, pair<int, int> > ds;
map<int, pair<int, int> > tds;
map<int, pair<int, int> >::iterator it;
map<int, pair<int, int> >::iterator itt;
// pass kk points, distant is dis
void enumerate(int v, int p, int kk, int dis, map<int, pair<int, int> > &tds)
{
	if (kk > k) return;
	it = tds.find(kk);
	if (it!=tds.end()){
		if (it->second.first == dis) {
			it->second.second += 1;
		}
		else if(it->second.first<dis){
			tds.erase(it);
			tds.insert(MP(kk, MP(dis, 1)));
		}
	}
	else{
		tds.insert(MP(kk, MP(dis, 1)));
	}
	for (int i = 0; i < T[v].size(); i++){
		int w = T[v][i].v;
		if (w == p || centroid[w]) continue;
		enumerate(w, v, kk + 1, dis + T[v][i].w, tds);
	}
}

ll ans, num;

void solve(int v)
{
	compute_subtree_size(v, -1);
	int s = search_centroid(v, -1, ssize[v]).second;
	centroid[s] = true;
	for (int i = 0; i < T[s].size(); i++){
		if (centroid[T[s][i].v]) continue;
		solve(T[s][i].v);
	}
	ds.clear();
	ds.insert(MP(1, MP(0, 1)));
	for (int i = 0; i < T[s].size(); i++){
		if (centroid[T[s][i].v]) continue;
		tds.clear();
		enumerate(T[s][i].v, s, 1, T[s][i].w, tds);
		it = tds.begin();
		while (it != tds.end()){
			int kk = it->first;
			if (ds.count(k - kk)){
				itt = ds.find(k - kk);
				int ldis = it->second.first + itt->second.first;
				if (ldis>ans) {
					ans = ldis; num = it->second.second*itt->second.second;
				}
				else if (ldis == ans){
					num += it->second.second*itt->second.second;
				}
			}
			++it;
		}
		it = tds.begin();
		while (it != tds.end()){
			int kk = it->first + 1;
			if (ds.count(kk)){
				itt = ds.find(kk);
				if (it->second.first > itt->second.first){
					ds.erase(itt);
					ds.insert(MP(kk, it->second));
				}
				else if (it->second.first == itt->second.first) itt->second.second += it->second.second;
			}
			else{
				ds.insert(MP(kk, it->second));
			}
			++it;
		}
	}
	centroid[s] = false;
}

int main()
{
	int TE; cin >> TE;
	while (TE--){
		scanf("%d%d%d", &n, &m, &k);
		for (int i = 0; i <= n; i++) {
			G[i].clear(); E[i].clear(); EE[i].clear(); T[i].clear();
		}
		int ui, vi, wi;
		for (int i = 0; i < m; i++){
			scanf("%d%d%d", &ui, &vi, &wi);
			G[ui].push_back(Edge(vi, wi));
			G[vi].push_back(Edge(ui, wi));
		}
		spfa();
		ans = 0, num = 0;
		solve(1);
		cout << ans << " " << num << endl;
	}
	return 0;
}
posted @ 2014-07-23 19:34  chanme  阅读(863)  评论(1编辑  收藏  举报