[spoj] FTOUR2 FREE TOUR II || 树分治

原题

给出一颗有n个点的树,其中有M个点是拥挤的,请选出一条最多包含k个拥挤的点的路径使得经过的权值和最大。


正常树分治,每次处理路径,更新答案。
计算每棵子树的deep(本题以经过拥挤节点个数作为deep),然后记录mx[i]为当前为止经过i个拥挤节点所达到的最大价值,tmp[i]为当前所在树中经过i个拥挤节点所达到的最大价值,用于更新答案即可。

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#define N 200010
using namespace std;
int ans,n,K,m,cnt,head[N],f[N];
vector < pair<int,int> > v;
struct hhh
{
    int to,next,w;
}edge[2*N];

int read()
{
    int ans=0,fu=1;
    char j=getchar();
    for (;j<'0' || j>'9';j=getchar()) if (j=='-') fu=-1;
    for (;j>='0' && j<='9';j=getchar()) ans*=10,ans+=j-'0';
    return ans*fu;
}

void add(int u,int v,int w)
{
    edge[cnt].to=v;edge[cnt].next=head[u];edge[cnt].w=w;head[u]=cnt++;
    edge[cnt].to=u;edge[cnt].next=head[v];edge[cnt].w=w;head[v]=cnt++;
}

void getroot(int x,int fa)
{
    sze[x]=1;
    son[x]=0;
    for (int i=head[x];i;i=edge[i].next)
	if (!vis[edge[i].to] && edge[i].to!=fa)
	{
	    getroot(edge[i].to,x);
	    son[x]=max(son[x],sze[edge[i].to]);
	    sze[x]+=sze[edge[i].to];
	}
    son[x]=max(son[x],sum-sze[x]);
    if (son[x]<son[rt]) rt=x;
}

void getdis(int x,int fa)
{
    deep_mx=max(deep_mx,deep[x]);
    for (int i=head[x];i;i=edge[i].next)
	if (!vis[edge[i].to] && edfe[i].to!=fa)
	{
	    deep[edge[i].to]=deep[x]+color[edge[i].to];
	    dis[edge[i].to]=dis[x]+edge[i].w;
	    getdis(edge[i].to,x);
	}
}

void getmx(int x,int fa)
{
    tmp[deep[x]]=max(tmp[deep[x]],dis[x]);
    for (int i=head[x];i;i=edge[i].to)
	if (!vis[edge[i].to] && edge[i].to!=fa)
	    getmx(edge[i].to,x);
}

void solve(int x)
{
    vis[x]=1;
    v.clear();
    for (int i=head[x];i;i=edge[i].next)
	if (!vis[edge[i].to])
	{
	    deep_mx=0;
	    deep[edge[i].to]=color[edge[i].to];
	    dis[edge[i].to]=edge[i].ww;
	    getdis(edge[i].to,x);
	    v.push_back(make_pair(deep_mx,edge[i].to));
	}
    sort(v.begin(),v.end());
    int s=v.size();
    for (int i=0;i<s;i++)
    {
	getmx(st[i].second,x);
	int now=0;
	if (i!=0)
	    for (int j=v[i].first;j>=0;j--)
	    {
		while (now+j<K && now<st[i-1].first)
		    now++,mx[now]=max(mx[now],mx[now-1]);
		if (now+j<=K) ans=max(mx[now]+tmp[j]);
	    }
	if (i!=s-1)
	    for (int j=0;j<=v[i].first;j++)
		mx[j]=max(mx[j],tmp[j]),tmp[j]=0;
	else
	    for (int j=0;j<=v[i].first;j++)
	    {
		if (j<=K) ans=max(ans,max(tmp[j],mx[j]));
		tmp[j]=mx[j]=0;
	    }
    }
}

int main()
{
    n=read();
    K=read();
    m=read();
    for (int i=1;i<=m;i++)
    {
	int x=read();
	color[x]=1;
    }
    for (int i=1,u,v,w;i<n;i++)
    {
	u=read();v=read();w=read();
	add(u,v,w);
    }
    sum=n;
    f[0]=n;
    getroot(1,0);
    solve(rt);
    printf("%d",ans);
    return 0;
}
posted @ 2017-12-18 14:03  Mrha  阅读(203)  评论(0编辑  收藏  举报