Bzoj_3572 [Hnoi2014]世界树

题意
还是一个n个结点的树,每次询问还是选定k个点。规定每个点会给距离它最近的标记点(距离相同,编号最小)贡献1的权值,每次询问即是标记k个点,然后问这k个点的权值。
\(N \leqslant 300000, q \leqslant 300000,k_1+k_2+...+k_q \leqslant 300000\)

题解
算法还是虚树。
然后有一个显然的性质。
对于每个点x,它的所有儿子的子树,内部有标记点的子树除外,所贡献的点一定与点x相同。
然后我们先只考虑虚树上的结点。
在虚树上,很容易可以算出每个点距离最近的标记点。
那么我们也可以连带把每个点连带的子树的贡献也统计好。
那么还剩下哪些点没统计呢?
虚树边上的点以及它们所连带的子树。
而对于边<u,v>上的点,一定贡献给u或v所贡献的点。
我们画个图考虑一下。

点X和点Y为虚树上的点(非标记点)。距离点X最近的标记点的编号为5,到点X的距离为3,到点Y最近的标记点的编号为2,距离为4。首先虚树边<X,Y>上与点x最近的点a所连带的1号子树,一定是贡献给5号点,同时得到5号点到点a的距离为4,与到2号点到点Y的距离相等。那么这之间的点一定被2号点与5号点平分,上面一半归5号点,下面一般归2号点,但此时存在点c,到2号点和5号点的距离均相等,采取距离相等,编号最小的原则,点c及2号子树,归2号点所有,也就是说点c及以下贡献2号点,点b及以上贡献5号点,到此分配完成。

简单来说,我们对于每个虚树边<u,v>,我们根据两端点贡献点的信息,算出边上的断点,断点以上给点u的贡献点,断点以下给点v的贡献点。具体详见代码。

#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define inf 0x7f7f7f
using namespace std;
const int maxn=3e5;
int n,m,tot,root,Time;
int Lg[maxn+8],a[maxn+8];
int pre[maxn*2+8],now[maxn+8],son[maxn*2+8];
int dep[maxn+8],f[maxn+8][20],siz[maxn+8],dfn[maxn+8];
int st[maxn+8],ans[maxn+8];

int read()
{
    int x=0,f=1;char ch=getchar();
    for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
    for (;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
    return x*f;
}
bool cmp(int x,int y){return dfn[x]<dfn[y];}

void add(int u,int v)
{
    pre[++tot]=now[u];
    now[u]=tot;
    son[tot]=v;
}

void dfs(int x,int fa)
{
    dfn[x]=++Time;
    dep[x]=dep[fa]+1;
    f[x][0]=fa;
    siz[x]=1;
    for (int i=1;i<=log(dep[x])/log(2);i++) f[x][i]=f[f[x][i-1]][i-1];
    for (int p=now[x];p;p=pre[p])
	{
	    int child=son[p];
	    if (child==fa) continue;
	    dfs(child,x);
	    siz[x]+=siz[child];
	}
}

int jump(int x,int d){if (d<0) return 0;for (;d;d-=d&(-d)) x=f[x][Lg[d&(-d)]];return x;}

int Get_Lca(int x,int y)
{
    if (dep[x]>dep[y]) swap(x,y);
    y=jump(y,dep[y]-dep[x]);
    if (x==y) return x;
    for (int i=log(dep[x])/log(2);~i&&f[x][0]!=f[y][0];i--)
	if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
    return f[x][0];
}

struct Pnt
{
    int x,dis;
};

bool operator <(Pnt a,Pnt b){return a.dis!=b.dis?a.dis<b.dis:a.x<b.x;}
Pnt operator +(Pnt a,int b){return (Pnt){a.x,a.dis+b};}

struct Virtual_Tree
{
    int tot,tail;
    int st[maxn+8];
    int pre[maxn*2+8],now[maxn+8],son[maxn*2+8],val[maxn*2+8];
    int color[maxn+8],cnt[maxn+8];
    Pnt f[maxn+8];
    void clear()
    {
	tot=0;
	while(tail) now[st[tail--]]=0;
	for (int i=1;i<=m;i++) color[a[i]]=0,ans[i]=0;
    }
    void add(int u,int v,int w)
    {
	if (!now[u]) st[++tail]=u;
	pre[++tot]=now[u];
	now[u]=tot;
	son[tot]=v;
	val[tot]=w;
    }
    void insert(int u,int v)
    {
	if (dep[u]>dep[v]) swap(u,v);
	add(u,v,dep[v]-dep[u]);
	add(v,u,dep[v]-dep[u]);
    }
    void dfs1(int x,int fa)
    {
	f[x]=(Pnt){x,color[x]?0:inf};
	cnt[x]=siz[x];
	for (int p=now[x];p;p=pre[p])
	    {
		int child=son[p];
		if (child==fa) continue;
		dfs1(child,x);
	        f[x]=min(f[x],f[child]+val[p]);
		cnt[x]-=siz[jump(child,val[p]-1)];
	    }
    }
    void dfs2(int x,int fa)
    {
	for (int p=now[x];p;p=pre[p])
	    {
		int child=son[p];
		if (child==fa) continue;
		f[child]=min(f[child],f[x]+val[p]);
		dfs2(child,x);
	    }
	for (int p=now[x];p;p=pre[p])
	    {
		int child=son[p];
		if (child==fa) continue;
		int d=f[x].dis-f[child].dis+val[p]-1,tmp1=siz[jump(child,val[p]-1)]+cnt[child]-siz[child],tmp2=siz[jump(child,d/2+((d>0)&(d&1)&(f[child].x<f[x].x)))]+cnt[child]-siz[child];
		tmp2=max(tmp2,0);tmp1-=tmp2;
	        ans[color[f[x].x]]+=tmp1;
		ans[color[f[child].x]]+=tmp2;
	    }
    }
}VT;

void solve()
{
    m=read();
    for (int i=1;i<=m;i++) VT.color[a[i]=read()]=i;
    sort(a+1,a+m+1,cmp);
    int tail=1;
    st[tail]=root;
    for (int i=1;i<=m;i++)
	{
	    int Lca=Get_Lca(st[tail],a[i]),lst=0;
	    while(dep[Lca]<dep[st[tail]])
		{
		    if (lst) VT.insert(lst,st[tail]);
		    lst=st[tail--];
		}
	    if (lst) VT.insert(lst,Lca);
	    if (dep[Lca]!=dep[st[tail]]) st[++tail]=Lca;
	    st[++tail]=a[i];
	}
    while(tail!=1) VT.insert(st[tail],st[tail-1]),tail--;
    VT.dfs1(root,0);
    VT.dfs2(root,0);
    for (int i=1;i<=m;i++) printf("%d ",ans[i]);puts("");
    VT.clear();
}

int main()
{
    n=read();
    for (int i=0;i<=20;i++) Lg[1<<i]=i;
    for (int i=1;i<n;i++)
	{
	    int u=read(),v=read();
	    add(u,v),add(v,u);
	}
    root=n+1;
    add(root,1),add(1,root);
    dfs(root,0);
    int Q=read();
    while(Q--) solve();
    return 0;
}
posted @ 2019-01-04 10:57  Alseo_Roplyer  阅读(132)  评论(0编辑  收藏  举报