P4383 [八省联考 2018] 林克卡特树

简化题意,给一棵树,找出恰好 \(k+1\) 条链,使这些链之和最大。

有恰好选出的字眼,并且原问题显然具有凸性,直接考虑 wqs 二分。

然后每条链会减去二分的 \(mid\),接下来就没有限制,求最大链和及链的数量,考虑树形 dp。

\(f_{x,0/1/2}\) 表示以 \(x\) 为根的子树,\(x\) 点入度为 \(0/1/2\) 所得的最大链值。

特别地,每次转移完 \(x\) 后令 \(f_{x,0}=\max(f_{x,0},f_{x,1}+mid,f_{x,2})\),表示 \(x\) 不再和祖先进行链的合并,此时得到的最优解。

接下来就能转移了,设 \(y\)\(x\) 的儿子,转移分三步:

\(f_{x,2}=\max(f_{x,2}+f_{y,0},f_{x,1}+f_{y,1}+w_{(x,y)}+mid)\)

\(f_{x,1}=\max(f_{x,1}+f_{y,0},f_{x,0}+f_{y,1}+w_{(x,y)})\)

\(f_{x,0}=\max(f_{x,0}+f_{y,0})\)

最后判断链数和 \(k\) 的大小关系,就做完了。

注意边权相等时,链数少的状态优先转移,因为维护的是上凸包。

#include<bits/stdc++.h>
using namespace std;
#define rd read()
#define ll long long
#define FOR(i,j,k) for(int i=j;i<=k;i++)
#define ROF(i,j,k) for(int i=j;i>=k;i--)
int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)) x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
	return x*f;
}
const int N=3e5+10;
const ll INF=1e12;
int n,k,head[N],tot;
struct node{int to,nxt,w;}edge[N<<1];
void add(int x,int y,int w){edge[++tot]={y,head[x],w};head[x]=tot;}
struct Node{
	ll val;int cnt;
	Node(ll x=0,int y=0):val(x),cnt(y){}
	friend bool operator<(Node a,Node b){return a.val!=b.val?a.val<b.val:a.cnt>b.cnt;}
	friend Node operator+(Node a,Node b){return Node{a.val+b.val,a.cnt+b.cnt};}
	friend Node operator+(Node a,ll v){return Node{a.val+v,a.cnt};}
}f[N][3],tmp;
void dfs(int x,int fa){
	f[x][0]=f[x][1]=f[x][2]=Node();
	f[x][2]=max(f[x][2],tmp);
	for(int i=head[x];i;i=edge[i].nxt){
		int y=edge[i].to;if(y==fa) continue;
		dfs(y,x);
		f[x][2]=max(f[x][2],max(f[x][2]+f[y][0],f[x][1]+f[y][1]+edge[i].w+tmp));
		f[x][1]=max(f[x][1],max(f[x][1]+f[y][0],f[x][0]+f[y][1]+edge[i].w));
		f[x][0]=max(f[x][0],f[x][0]+f[y][0]);
	}
	f[x][0]=max(f[x][0],max(f[x][1]+tmp,f[x][2]));
}
int main(){
	n=rd,k=rd,k++;
	FOR(i,1,n-1){int x=rd,y=rd,w=rd;add(x,y,w),add(y,x,w);}
	ll l=-INF,r=INF,ans=0 ;
	while(l<r){
		ll mid=(l+r+1)/2ll;
		tmp=Node{-mid,1};dfs(1,0);
		if(f[1][0].cnt==k){printf("%lld\n",f[1][0].val+k*mid);return 0;}
		if(f[1][0].cnt>k) l=mid+1;
		else r=mid;
	}
	tmp=Node{-l,1};dfs(1,0),printf("%lld\n",f[1][0].val+k*l);
	return 0;
}
posted @ 2024-11-05 14:34  summ1t  阅读(44)  评论(5)    收藏  举报