loading

CF1394D Boboniu and Jianghu

题意

给定一棵树,每个节点有权值 \(a_i,b_i\)。现在你现在要将整棵树进行链剖分,使得每条边都恰好在一条链里,并且每条链上的 \(b_i\) 单不降或单不升。链剖分的权值和为所有链的链上的点的 \(a_i\) 之和,求最小权值和链剖分。

\(n\le 2\times10^5\)

分析

我们不妨将单不降或单不升的链用方向区分开以简化问题,那么我们所需要做的就是给树上的每条边定向(不妨钦定边的终点的 \(b_i\) 较小)。那么,对于一条边来说,若 \(b_u\neq b_v\),则该边方向已经确定;而若 \(b_u=b_v\),则该边方向任意。

考虑将权值和拆到每个点上,计算每个点被不同链的覆盖次数。那么显然的,对于一个点,设指入该点的点数为 \(x\),由该点指出的点数为 \(y\),则最小覆盖次数为 \(\max(x,y)\)(因为一个指入和一个指出能配成一条链)。

考虑 DP,设 \(f_{i,0/1}\) 表示 \(i\) 这个点,和父亲的连边的定向状态为指向父亲节点/被父亲节点指。假设我们已经得出了所有子节点的 DP 值,要求出 \(\max(x,y)\) 可以考虑枚举,我们需要在可以自由定向的那些子 DP 状态选择一部分取 \(f_{u,0}\),另一部分取 \(f_{u,1}\)。我们可以通过钦定最开始全选 \(f_{u,0}\) 然后贪心的选择 \(f_{u,1}-f_{u,0}\) 最小的那些点换成 \(f_{u,1}\)。至此 DP 转移可以简单的做到 \(O(n\log n)\),总时间复杂度也是这个东西。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<map>
#include<unordered_map>
#include<vector>
#include<queue>
#include<stack>
#include<bitset>
#include<set>
#include<array>
#include<ctime>
#include<random>
#include<cassert>
#define x1 xx1
#define y1 yy1
#define IOS ios::sync_with_stdio(false)
#define ITIE cin.tie(0);
#define OTIE cout.tie(0);
#define PY puts("Yes")
#define PN puts("No")
#define PW puts("-1")
#define P0 puts("0")
#define P__ puts("")
#define PU puts("--------------------")
#define mp make_pair
#define fi first
#define se second
#define gc getchar
#define pc putchar
#define pb emplace_back
#define un using namespace
#define all(x) x.begin(),x.end()
#define mem(x,y) memset(x,y,sizeof x)
#define rep(a,b,c) for(int a=(b);a<=(c);++a)
#define per(a,b,c) for(int a=(b);a>=(c);--a)
#define reprange(a,b,c,d) for(int a=(b);a<=(c);a+=(d))
#define perrange(a,b,c,d) for(int a=(b);a>=(c);a-=(d))
#define graph(i,j,k,l) for(int i=k[j];i;i=l[i].nxt)
#define lowbit(x) ((x)&-(x))
#define lson(x) ((x)<<1)
#define rson(x) ((x)<<1|1)
//#define double long double
#define int long long
//#define int __int128
using namespace std;
using i64=long long;
using u64=unsigned long long;
using pii=pair<int,int>;
template<typename T1,typename T2>inline void ckmx(T1 &x,T2 y){x=x>y?x:y;}
template<typename T1,typename T2>inline void ckmn(T1 &x,T2 y){x=x<y?x:y;}
inline auto rd(){
	int qwqx=0,qwqf=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')qwqf=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){qwqx=(qwqx<<1)+(qwqx<<3)+ch-48;ch=getchar();}return qwqx*qwqf;
}
template<typename T>inline void write(T qwqx,char ch='\n'){
	if(qwqx<0){qwqx=-qwqx;putchar('-');}
	int qwqy=0;char qwqz[40];
	while(qwqx||!qwqy){qwqz[qwqy++]=qwqx%10+48;qwqx/=10;}
	while(qwqy--)putchar(qwqz[qwqy]);if(ch)putchar(ch);
}
bool Mbg;
const int maxn=2e5+5,inf=0x3f3f3f3f;
const long long llinf=0x3f3f3f3f3f3f3f3f;
int n,a[maxn],b[maxn];
vector<int>G[maxn];
int f[maxn][2];
int ans=llinf;
void dfs(int x,int y){
	int in=0,out=0;
	int sum=0;
	priority_queue<int,vector<int>,greater<int> >q;
	for(int u:G[x])if(u^y){
		dfs(u,x);
		if(b[x]==b[u]){
			sum+=f[u][0],++in;
			q.push(f[u][1]-f[u][0]);
		}
		if(b[x]>b[u]){
			sum+=f[u][1],++out;
		}
		if(b[x]<b[u]){
			sum+=f[u][0],++in;
		}
	}
	if(y){
		ckmn(f[x][0],max(out+1,in)*a[x]+sum),ckmn(f[x][1],max(out,in+1)*a[x]+sum);
		while(!q.empty()){
			sum+=q.top(),q.pop(),--in,++out;
			ckmn(f[x][0],max(out+1,in)*a[x]+sum),ckmn(f[x][1],max(out,in+1)*a[x]+sum);
		}
	}else{
		ckmn(ans,max(out,in)*a[x]+sum);
		while(!q.empty()){
			sum+=q.top(),q.pop(),--in,++out;
			ckmn(ans,max(out,in)*a[x]+sum);
		}
	}
}
inline void solve_the_problem(){
	n=rd();
	rep(i,1,n)a[i]=rd();
	rep(i,1,n)b[i]=rd();
	rep(i,2,n){
		int x=rd(),y=rd();
		G[x].emplace_back(y),G[y].emplace_back(x);
	}
	mem(f,0x3f);
	dfs(1,0);
	write(ans);
}
bool Med;
signed main(){
//	freopen(".in","r",stdin);freopen(".out","w",stdout);
	fprintf(stderr,"%.3lfMB\n",(&Mbg-&Med)/1048576.0);
	int _=1;
	while(_--)solve_the_problem();
}
/*

*/
posted @ 2025-04-30 20:08  dcytrl  阅读(20)  评论(2)    收藏  举报