【UOJ #205】【APIO 2016】Fireworks

http://uoj.ac/problem/205
好神的题啊。
dp[i][j]表示以i为根的子树调整成长度j需要的最小代价。
首先要观察到dp值是一个下凸壳。
因为从儿子合并到父亲时要把所有儿子的凸壳相加,得到的还是一个凸壳。
父亲要把它连向它父亲的边的影响加入时,设这条边长度为len,则相当于把当前的这个凸壳先右移len,斜率大于1的部分斜率都重置为1,斜率小于1的部分都向左移len再向上移len,其中空出来的长度为len的部分用斜率为-1的连接起来。
就是把原凸壳先整体上移len,再删掉斜率大于等于0的部分,再添上3条斜率分别为-1,0,1的直线。
直接维护凸壳的复杂度是\(O\left((n+m)^2\right)\)的。
再来考虑一下凸壳的性质:
一个凸壳在x=0处的取值是子树内所有边权和;
当这个凸壳没有考虑当前点到它父亲的边的贡献时,这个凸壳最右端的直线的斜率是它的儿子数;
凸壳上的直线的斜率只可能是整数;
现在有了上面的性质,可以更简单地表达一个凸壳。
有了凸壳在x=0处的取值,我们只要知道一个凸壳的导函数就可以还原出一个凸壳。
有了凸壳最右端直线的斜率,也就是导函数的最大值,我们只要知道一个凸壳的二阶导就可以还原出凸壳的导函数。
也就是说不用维护凸壳,直接维护凸壳的二阶导数就可以了。
二阶导可以更直观的看成拐点,每个在第i个位置的拐点对二阶导的贡献为1(拐点的位置可以重叠)。
每次合并时直接合并两个拐点集合就可以了,每次考虑父亲边的贡献时删掉最靠右边的儿子数+1个拐点,再添加两个拐点。
因为每次都删权值最大的拐点,拐点集合可以用可并堆维护。
最后用根节点的拐点集合还原出根节点的凸壳就可以了。
每个节点只可能加进来两个拐点,每个拐点最多被弹出一次,时间复杂度\(O\left((n+m)\log(m+m)\right)\)

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;

const int N = 600003;

struct node {
	node *ch[2];
	int v; ll pos;
	node(ll _pos = 0) : pos(_pos) {ch[0] = ch[1] = 0; v = 0;}
} *rt[N];

int dist(node *r) {return r ? r->v : -1;}

node *merge(node *l, node *r) {
	if (l == 0) return r;
	if (r == 0) return l;
	if (l->pos < r->pos) swap(l, r);
	l->ch[1] = merge(l->ch[1], r);
	if (dist(l->ch[0]) < dist(l->ch[1]))
		swap(l->ch[0], l->ch[1]);
	if (l->ch[1]) l->v = l->ch[1]->v + 1;
	else l->v = 0;
	return l;
}

void pop(node *&r) {
	if (r) r = merge(r->ch[0], r->ch[1]);
}

ll sum = 0, pp[N << 1];
int fa[N << 1], len[N << 1], n, m, d[N << 1];

int main() {
	scanf("%d%d", &n, &m);
	for (int i = 2; i <= n + m; ++i) {
		scanf("%d%d", fa + i, len + i);
		sum += len[i];
		++d[fa[i]];
	}
	
	node *n1, *n2;
	for (int i = n + m; i > 1; --i) {
		if (i > n) {
			rt[i] = merge(new node(len[i]), new node(len[i]));
			rt[fa[i]] = merge(rt[fa[i]], rt[i]);
			continue;
		}
		while (--d[i]) pop(rt[i]);
		n2 = rt[i]; pop(rt[i]);
		n1 = rt[i]; pop(rt[i]);
		rt[i] = merge(rt[i], new node(n1->pos + len[i]));
		rt[i] = merge(rt[i], new node(n2->pos + len[i]));
		rt[fa[i]] = merge(rt[fa[i]], rt[i]);
	}
	
	while (d[1]--) pop(rt[1]);
	int ptot = 0;
	while (rt[1]) {
		pp[++ptot] = rt[1]->pos;
		pop(rt[1]);
	}
	
	ll prepos = 0;
	while (ptot) {
		if (pp[ptot] != prepos) {
			sum -= (pp[ptot] - prepos) * ptot;
			prepos = pp[ptot];
		}
		--ptot;
	}
	
	printf("%lld\n", sum);
	return 0;
}
posted @ 2017-04-25 07:54  abclzr  阅读(685)  评论(0编辑  收藏  举报