P3714 [BJOI2017]树的难题 题解
题意简述:给出一棵树,每条边有一个颜色(可以相同),每种颜色有一个权值,对于一条树上的简单路径,路径上经过的所有边按顺序组成一个颜色序列,序列可以划分成若干个相同颜色段。定义路径权值为颜色序列上每个同颜色段的颜色权值之和。求经过边数在 \(l\) 到 \(r\) 之间的所有简单路径中,路径权值的最大值。
Solution:对于这种树上路径统计的题,考虑点分治。
在某次点分治的过程中,要统计两个东西:一端为根节点(当前,下同)的路径,经过根且两端点不为根节点的路径
对于1,可以用一遍\(DFS\)处理出来
对于2,它是由两个1中路径合并而来,当合并时,我们不关心具体颜色,只关心两条路径相接的地方是否同色,我们定义路径与根相接的边的颜色为这条路径的"颜色"
为了减少合并次数,我们把颜色排序,把相同颜色的路径合在一起处理
对于一条路径,分别讨论它与异色和同色路径相连,且路径长度满足要求时的最大答案
以路径长度为下标,开两个线段树维护,第一棵线段树维护异色路径的答案,当颜色改变时,就合并两棵线段树(变为下一种颜色的第一棵线段树)
于是可以用线段树合并解决
时间复杂度\(O(nlog^2n)\)
trick:在统计2时,不能把同一子树内的两条路径接起来,所以我们在统计某棵子树内的答案时,先把这棵子树内的所有答案记下来,最后再把它们加进线段树
把这些答案也塞进一棵线段树中,然后再合并,可以减小代码难度
注意线段树要及时清空
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;
#define int long long
const int INF = 1234567891011;
const int N = 200005;
typedef pair<int,int> pii;
vector<pii> G[N];
int cnt,n,m,l,r,Rt,rt1,rt2;
int Max[N],siz[N],c[N],vis[N],sum,rt,ans = -INF;
inline int read() {
int x = 0,f = 1;
char ch = getchar();
while(ch < '0'||ch > '9') {
if(ch == '-')
f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
x = (x<<1) + (x<<3) + (ch^48);
ch = getchar();
}
return x*f;
}
inline void add(int x,int y,int z) {
// cout << "edge:" << x << " " << y << " " << z << endl;
G[x].push_back(make_pair(z,y));
// G[y].push_back(make_pair(z,x));
}
struct Seg{
int lc,rc,maxx;
#define lc(x) tr[x].lc
#define rc(x) tr[x].rc
#define maxx(x) tr[x].maxx
}tr[N * 82];
inline int build() {
++cnt;
lc(cnt) = rc(cnt) = 0;
maxx(cnt) = -INF;
return cnt;
}
inline void clear(int x) {
lc(x) = rc(x) = 0;
maxx(x) = -INF;
return;
}
inline void Clear(int x) {
if(lc(x)) Clear(lc(x));
if(rc(x)) Clear(rc(x));
clear(x);
return;
}
inline void pushup(int x) {
maxx(x) = max(maxx(lc(x)),maxx(rc(x)));
}
inline void change(int x,int l,int r,int pos,int val) {
if(l == r) {
maxx(x) = max(maxx(x),val);
return;
}
int mid = (l+r)>>1;
if(pos <= mid) {
if(!lc(x)) lc(x) = build();
change(lc(x),l,mid,pos,val);
}
else {
if(!rc(x)) rc(x) = build();
change(rc(x),mid + 1,r,pos,val);
}
pushup(x);
}
inline int merge(int p,int q,int l,int r) {
if(!p || !q) return p + q;
if(l == r) {
maxx(p) = max(maxx(p),maxx(q));
clear(q);
return p;
}
int mid = (l + r) >> 1;
lc(p) = merge(lc(p),lc(q),l,mid);
rc(p) = merge(rc(p),rc(q),mid+1,r);
pushup(p);
clear(q);
return p;
}
inline int query(int p,int l,int r,int x,int y) {
//[x,y]为待查询区间
if(!p) return -INF;
if(x <= l && r <= y) return maxx(p);
int mid = (l+r)>>1;
if(x > mid) return query(rc(p),mid+1,r,x,y);
if(y <= mid) return query(lc(p),l,mid,x,y);
return max(query(lc(p),l,mid,x,y),query(rc(p),mid+1,r,x,y));
}
//----------------------
inline void calcsiz(int x,int fa) {
siz[x] = 1,Max[x] = 0;
for(int i = 0;i < G[x].size();i++) {
int y = G[x][i].second;
if(y == fa || vis[y]) continue;
calcsiz(y,x);
siz[x] += siz[y];
Max[x] = max(Max[x],siz[y]);
}
Max[x] = max(Max[x],sum - siz[x]);
if(Max[x] < Max[rt]) rt = x;
}
inline void work(int x,int fa,int in,int In,int len,int now) {
if(len > r) return;
// cout << "work:" << x << " " << fa << " " << in << " " << In << " " << len << " " << now << endl;
ans = max(ans,query(rt1,1,r,max(l - len,1LL), max(r - len,1LL)) + now);
// cout << "cnm" << endl;
ans = max(ans,query(rt2,1,r,max(l - len,1LL), max(r - len,1LL)) + now - c[in]);
// cout << "cnm" << endl;
if(len >= l) ans = max(ans,now);
change(Rt,1,r,len,now);
for(int i = 0;i < G[x].size();i++) {
int y = G[x][i].second,z = G[x][i].first;
if(y == fa || vis[y]) continue;
if(z == In) work(y,x,in,In,len + 1,now);
else work(y,x,in,z,len + 1,now + c[z]);
}
}
inline void solve(int x,int fa) {
// cout << x << " " << fa << endl;
vis[x] = 1;rt1 = build(),rt2 = build();
sort(G[x].begin(),G[x].end());
for(int i = 0;i < G[x].size();i++) {
int y = G[x][i].second,z = G[x][i].first;
if(y == fa || vis[y]) continue;
// cout << "edge:" << y << " " << z << endl;
if(i && G[x][i].first != G[x][i-1].first) rt1 = merge(rt1,rt2,1,r),rt2 = build();
// cout << "fuck" << endl;
Rt = build();
work(y,x,z,z,1,c[z]);
// cout << "ok" << endl;
merge(rt2,Rt,1,r);
// cout << "shit" << endl;
}
cnt = 0;
Clear(rt1);Clear(rt2);
for(int i = 0;i < G[x].size();i++) {
int y = G[x][i].second;
if(y == fa || vis[y]) continue;
rt = 0,sum = siz[y];
Max[rt] = INF;
calcsiz(y,x);
calcsiz(rt,-1);
solve(rt,x);
}
}
signed main () {
tr[0].maxx = -INF;
n = read();m = read();l = read();r = read();
for(int i = 1;i <= m;i++) c[i] = read();
for(int i = 1;i < n;i++) {
int u,v,w;
u = read();v = read();w = read();
add(u,v,w);add(v,u,w);
}
rt = 0,sum = n;
Max[rt] = INF;
calcsiz(1,-1);
calcsiz(rt,-1);
// cout << "rt:" << rt << endl;
solve(rt,-1);
printf("%lld\n",ans);
return 0;
}

浙公网安备 33010602011771号