【题解】CF1394D
题目描述
有一棵 \(n\) 个点的树(\(1\le n\le 2\times10^5\)),第 \(i\) 个点有参数 \(a_i,b_i\)。(\(1\le a_i,b_i\le10^6\))
现在要求把这棵树剖分成若干条链(链包括端点),使每条边恰好出现在一条链中,且要求链上的点的 \(b_i\) 单调不降或单调不增。一条链的权值定义为链上所有点的 \(a_i\) 之和。
求在所有剖分方案中,链的总权值最小为多少。
题解
可以观察到如果所有值都不同,那么答案一定是naive的,于是只考虑相同的点。
对于相同的点之间的边,一定是由下连到上的,于是我们可以对它进行定向做Dp,\(f_{i,0/1}\)代表第i个点及其它的子树内当前点到它的的fa的边的方向是上/下的最小和。
#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int rd(){
int f=1,j=0;
char w=getchar();
while(!isdigit(w)){
if(w=='-')f=-1;
w=getchar();
}
while(isdigit(w)){
j=j*10+w-'0';
w=getchar();
}
return f*j;
}
const int N=200010;
int head[N],to[N*2],fro[N*2],tail;
int n,A[N],B[N],f[N][2],du[N];
struct node{
int p,val;
bool operator <(const node &a)const{return val<a.val;}
};
vector<node>own[N];
inline void adlin(int x,int y){
to[++tail]=y,fro[tail]=head[x],head[x]=tail;
return ;
}
priority_queue<node>que;
void dfs(int u,int fa){
int sum0=0,num0=0,sum1=0,num1=0;
for(int k=head[u];k;k=fro[k]){
int v=to[k];
du[u]++;
if(v==fa)continue;
dfs(v,u);
if(B[v]>B[u])sum0+=f[v][0],num0++;
else if(B[v]<B[u])sum1+=f[v][1],num1++;
else num0++,sum0+=f[v][0],own[u].push_back((node){v,f[v][0]-f[v][1]});
}
// cout<<u<<":"<<sum0<<" "<<num0<<"-"<<sum1<<" "<<num1<<"\n";
if(B[fa]<=B[u]){
int ansn=1e18;
while(!que.empty())que.pop();
int a=num0,b=num1,c=sum0,d=sum1;
for(int i=0;i<own[u].size();i++){
que.push(own[u][i]);
}
// cout<<u<<":"<<a<<" "<<c<<"-"<<b<<" "<<d<<"\n";
ansn=min(ansn,c+d+max(a,b+(u!=1))*A[u]);
while(!que.empty()){
int p=que.top().p;que.pop();
// cout<<u<<":"<<a<<" "<<c<<"-"<<b<<" "<<d<<"\n";
a--,c-=f[p][0],b++,d+=f[p][1];
ansn=min(ansn,c+d+max(a,b+(u!=1))*A[u]);
}
f[u][0]=ansn;
// if(du[u]==1&&fa!=0)f[u][0]=A[u];
}
if(B[fa]>=B[u]){
int ansn=1e18;
while(!que.empty())que.pop();
int a=num0,b=num1,c=sum0,d=sum1;
for(int i=0;i<own[u].size();i++){
que.push(own[u][i]);
}
ansn=min(ansn,c+d+max(a+(u!=1),b)*A[u]);
while(!que.empty()){
int p=que.top().p;que.pop();
a--,c-=f[p][0],b++,d+=f[p][1];
ansn=min(ansn,c+d+max(a+(u!=1),b)*A[u]);
}
f[u][1]=ansn;
// if(du[u]==1&&fa!=0)f[u][1]=A[u];
}
return ;
}
signed main(){
n=rd();
for(int i=1;i<=n;i++)A[i]=rd();
for(int i=1;i<=n;i++)B[i]=rd();
for(int i=1;i<n;i++){
int x=rd(),y=rd();
adlin(x,y),adlin(y,x);
}
dfs(1,0);
// for(int i=1;i<=n;i++)cout<<i<<":"<<f[i][0]<<" "<<f[i][1]<<"\n";
printf("%lld",f[1][0]);
return 0;
}

浙公网安备 33010602011771号