P3714 [BJOI2017]树的难题 题解

原题戳这里

题意简述:给出一棵树,每条边有一个颜色(可以相同),每种颜色有一个权值,对于一条树上的简单路径,路径上经过的所有边按顺序组成一个颜色序列,序列可以划分成若干个相同颜色段。定义路径权值为颜色序列上每个同颜色段的颜色权值之和。求经过边数在 \(l\)\(r\) 之间的所有简单路径中,路径权值的最大值。

Solution:对于这种树上路径统计的题,考虑点分治。
在某次点分治的过程中,要统计两个东西:一端为根节点(当前,下同)的路径,经过根且两端点不为根节点的路径
对于1,可以用一遍\(DFS\)处理出来
对于2,它是由两个1中路径合并而来,当合并时,我们不关心具体颜色,只关心两条路径相接的地方是否同色,我们定义路径与根相接的边的颜色为这条路径的"颜色"
为了减少合并次数,我们把颜色排序,把相同颜色的路径合在一起处理
对于一条路径,分别讨论它与异色和同色路径相连,且路径长度满足要求时的最大答案
以路径长度为下标,开两个线段树维护,第一棵线段树维护异色路径的答案,当颜色改变时,就合并两棵线段树(变为下一种颜色的第一棵线段树)
于是可以用线段树合并解决
时间复杂度\(O(nlog^2n)\)

trick:在统计2时,不能把同一子树内的两条路径接起来,所以我们在统计某棵子树内的答案时,先把这棵子树内的所有答案记下来,最后再把它们加进线段树
把这些答案也塞进一棵线段树中,然后再合并,可以减小代码难度

注意线段树要及时清空

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <queue>

using namespace std;

#define int long long
const int INF = 1234567891011;
const int N = 200005;
typedef pair<int,int> pii;
vector<pii> G[N];
int cnt,n,m,l,r,Rt,rt1,rt2;
int Max[N],siz[N],c[N],vis[N],sum,rt,ans = -INF;
inline int read() {
    int x = 0,f = 1;
    char ch = getchar();
    while(ch < '0'||ch > '9') {
        if(ch == '-')
            f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') {
        x = (x<<1) + (x<<3) + (ch^48);
        ch = getchar();
    }
    return x*f;
}

inline void add(int x,int y,int z) {
//	cout << "edge:" << x << " " << y << " " << z << endl;
	G[x].push_back(make_pair(z,y));
//	G[y].push_back(make_pair(z,x));
}

struct Seg{
	int lc,rc,maxx;
	#define lc(x) tr[x].lc
	#define rc(x) tr[x].rc
	#define maxx(x) tr[x].maxx
}tr[N * 82];

inline int build() {
	++cnt;
	lc(cnt) = rc(cnt) = 0;
	maxx(cnt) = -INF;
	return cnt;
}

inline void clear(int x) {
	lc(x) = rc(x) = 0;
	maxx(x) = -INF;
	return;
}

inline void Clear(int x) {
	if(lc(x)) Clear(lc(x));
	if(rc(x)) Clear(rc(x));
	clear(x);
	return;
}

inline void pushup(int x) {
	maxx(x) = max(maxx(lc(x)),maxx(rc(x)));
}

inline void change(int x,int l,int r,int pos,int val) {
	if(l == r) {
		maxx(x) = max(maxx(x),val);
		return;
	}
	int mid = (l+r)>>1;
	if(pos <= mid) {
		if(!lc(x)) lc(x) = build();
		change(lc(x),l,mid,pos,val);
	}
	else {
		if(!rc(x)) rc(x) = build();
		change(rc(x),mid + 1,r,pos,val);
	}
	pushup(x);
}

inline int merge(int p,int q,int l,int r) {
	if(!p || !q) return p + q;
	if(l == r) {
		maxx(p) = max(maxx(p),maxx(q));
		clear(q);
		return p;
	}
	int mid = (l + r) >> 1;
	lc(p) = merge(lc(p),lc(q),l,mid);
	rc(p) = merge(rc(p),rc(q),mid+1,r);
	pushup(p);
	clear(q);
	return p;
}

inline int query(int p,int l,int r,int x,int y) {
	//[x,y]为待查询区间 
	if(!p) return -INF;
	if(x <= l && r <= y) return maxx(p);
	int mid = (l+r)>>1;
	if(x > mid) return query(rc(p),mid+1,r,x,y);
	if(y <= mid) return query(lc(p),l,mid,x,y); 
	return max(query(lc(p),l,mid,x,y),query(rc(p),mid+1,r,x,y));
}

//----------------------

inline void calcsiz(int x,int fa) {
	siz[x] = 1,Max[x] = 0;
	for(int i = 0;i < G[x].size();i++) {
		int y = G[x][i].second;
		if(y == fa || vis[y]) continue;
		calcsiz(y,x);
		siz[x] += siz[y];
		Max[x] = max(Max[x],siz[y]);
	}
	Max[x] = max(Max[x],sum - siz[x]);
	if(Max[x] < Max[rt]) rt = x;
}

inline void work(int x,int fa,int in,int In,int len,int now) { 
	if(len > r) return;
//	cout << "work:" << x << " " << fa << " " << in << " " << In << " " << len << " " << now << endl; 
	ans = max(ans,query(rt1,1,r,max(l - len,1LL), max(r - len,1LL)) + now);
//	cout << "cnm" << endl; 
	ans = max(ans,query(rt2,1,r,max(l - len,1LL), max(r - len,1LL)) + now - c[in]);
//	cout << "cnm" << endl;
	if(len >= l) ans = max(ans,now);
	change(Rt,1,r,len,now);
	for(int i = 0;i < G[x].size();i++) {
		int y = G[x][i].second,z = G[x][i].first;
		if(y == fa || vis[y]) continue;
		if(z == In) work(y,x,in,In,len + 1,now); 
		else work(y,x,in,z,len + 1,now + c[z]);
	}
} 

inline void solve(int x,int fa) {
//	cout << x << " " << fa << endl;
	vis[x] = 1;rt1 = build(),rt2 = build();
	sort(G[x].begin(),G[x].end());
	for(int i = 0;i < G[x].size();i++) {
		int y = G[x][i].second,z = G[x][i].first;
		if(y == fa || vis[y]) continue;
	//	cout << "edge:" << y << " " << z << endl;
		if(i && G[x][i].first != G[x][i-1].first) rt1 = merge(rt1,rt2,1,r),rt2 = build();
	//	cout << "fuck" << endl;
		Rt = build();
		work(y,x,z,z,1,c[z]);
	//	cout << "ok" << endl;
		merge(rt2,Rt,1,r);
	//	cout << "shit" << endl;
	}
	cnt = 0;
	Clear(rt1);Clear(rt2);
	for(int i = 0;i < G[x].size();i++) {
		int y = G[x][i].second;
		if(y == fa || vis[y]) continue;
		rt = 0,sum = siz[y];
		Max[rt] = INF;
		calcsiz(y,x);
		calcsiz(rt,-1);
		solve(rt,x);
	}
}

signed main () {
	tr[0].maxx = -INF;
	n = read();m = read();l = read();r = read();
	for(int i = 1;i <= m;i++) c[i] = read();
	for(int i = 1;i < n;i++) {
		int u,v,w;
		u = read();v = read();w = read();
		add(u,v,w);add(v,u,w); 
	}
	rt = 0,sum = n;
	Max[rt] = INF;
	calcsiz(1,-1);
	calcsiz(rt,-1);
//	cout << "rt:" << rt << endl;
	solve(rt,-1);
	printf("%lld\n",ans);
	return 0;
}
posted @ 2021-04-16 15:12  ctt2006  阅读(91)  评论(0)    收藏  举报