树上最小权链覆盖:可并堆

题目描述

给出一棵树和若干条直上直下的链,每条链有权值(非负)

用权值和尽量小的链覆盖树上所有的点

\(n,m \leq 1e6\)

分析

猫老师课件上的题。

先考虑本问题在链上的形式,即最小权区间覆盖,这个问题有一个经典做法,即使用线段树优化DP。但是我们发现,由于树的形态,这个算法在树上不具有较强的扩展性(或许是因为博主太弱了没想到)。实际上,最小权区间覆盖问题也可以使用一个堆来解决。具体算法是:把所有区间按左端点排序,维护一个小根堆,堆中存的是一些二元组\((c,d)\),其中\(c\)是关键字,表示的是可以用\(c\)的代价覆盖\([1,d]\)这段区间。从左往右扫所有区间,每扫到一个权值为\(w\)区间\([l,r]\),就把二元组\((\min_{\in \text{heap}} \{c\}+w,r)\)加入堆中,当处理完所有左端点为\(l\)的区间后,删除堆中所有\(d<l\)的二元组。因为我们选择的区间不可能存在包含关系,所以算法正确性显然。

这个算法可以通过支持合并的数据结构扩展到树上,这里我们使用左偏树实现的可并堆。在树上多个子树的堆合并前,每个堆的所有元素要加上其他堆的\(\min c\)之和。

代码

未经过对拍,不保证其正确性。

#include <bits/stdc++.h>
#define rin(i,a,b) for(register int i=(a);i<=(b);++i)
#define irin(i,a,b) for(register int i=(a);i>=(b);--i)
#define trav(i,a) for(register int i=head[a];i;i=e[i].nxt)
typedef long long LL;
using std::cin;
using std::cout;
using std::endl;

inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}

const int MAXN=1e6+5;

int n,m,ecnt,head[MAXN];
int dep[MAXN];
LL minf[MAXN];
int root[MAXN],tot;

struct Edge{
	int to,nxt;
}e[MAXN<<1];

struct path{
	int top,val;
};

std::vector<path> vec[MAXN];

struct leftist{
	int ch[2];
	int dis,dep;
	LL dat,tag;
}lt[MAXN];

inline void add_edge(int bg,int ed){
	++ecnt;
	e[ecnt].to=ed;
	e[ecnt].nxt=head[bg];
	head[bg]=ecnt;
}

void dfs1(int x,int pre,int depth){
	dep[x]=depth;
	trav(i,x){
		int ver=e[i].to;
		if(ver==pre) continue;
		dfs1(ver,x,depth+1);
	}
}

#define lc lt[x].ch[0]
#define rc lt[x].ch[1]

inline void pushtag(int x,LL _kk){
	lt[x].dat+=_kk;
	lt[x].tag+=_kk;
}

inline void pushdown(int x){
	if(!lt[x].tag) return;
	if(lc) pushtag(lc,lt[x].tag);
	if(rc) pushtag(rc,lt[x].tag);
	lt[x].tag=0;
}

int merge(int x,int y){
	if(!x||!y) return x+y;
	pushdown(x);pushdown(y);
	if(lt[x].dat>lt[y].dat) std::swap(x,y);
	rc=merge(rc,y);
	if(lt[lc].dis<lt[rc].dis) std::swap(lc,rc);
	lt[x].dis=lt[rc].dis+1;
	return x;
}

int del(int x){
	pushdown(x);
	return merge(lc,rc);
}

#undef lc
#undef rc

void dfs2(int x,int pre){
	LL temp=0;
	trav(i,x){
		int ver=e[i].to;
		if(ver==pre) continue;
		dfs2(ver,x);
		temp+=minf[ver];
	}
	rin(i,0,(int)vec[x].size()-1){
		lt[++tot]=(leftist){0,0,1,vec[x][i].top,vec[x][i].val+temp,0};
		root[x]=merge(root[x],tot);
	}
	trav(i,x){
		int ver=e[i].to;
		if(ver==pre) continue;
		pushtag(root[ver],temp-minf[ver]);
		root[x]=merge(root[x],root[ver]);
	}
	while(lt[root[x]].dep>dep[x]) root[x]=del(root[x]);
	minf[x]=lt[root[x]].dat;
}

int main(){
	n=read();
	rin(i,2,n){
		int u=read(),v=read();
		add_edge(u,v);
		add_edge(v,u);
	}
	dfs1(1,0,1);
	m=read();
	rin(i,1,m){
		int u=read(),v=read(),w=read();
		if(dep[u]<dep[v]) std::swap(u,v);
		vec[u].push_back((path){dep[v],w});
	}
	dfs2(1,0);
	printf("%lld\n",minf[1]);
	return 0;
}

/*
7
1 2
1 3
2 4
2 5
3 6
3 7
6
1 4 3
1 5 2
2 5 1
1 6 5
6 6 1
3 7 2

7
*/

posted on 2019-02-15 11:20  ErkkiErkko  阅读(530)  评论(0编辑  收藏  举报