【洛谷】P3591 [POI2015] ODW(根号分治+长链剖分)

原题链接

题意

给定一棵 \(n\) 个点的树,树上每条边的长度都为 \(1\),第 \(i\) 个点的权值为 \(a_i\)

Byteasar 想要走遍这整棵树,他会按照某个 \(1\)\(n\) 的全排列 \(b\)\(n-1\) 次,第 \(i\) 次他会从 \(b_i\) 点走到 \(b_{i + 1}\) 点,并且这一次的步伐大小为 \(c_i\)

对于一次行走,假设起点为 \(x\),终点为 \(y\),步伐为 \(k\),那么Byteasar会从 \(x\) 开始,每步往前走 \(k\) 条边,数据保证了每次行走的距离是 \(k\) 的倍数

请帮助 Byteasar 统计出每一次行走时经过的所有点的权值和。

\(n \leq 50000\)

思路

不难发现,本题中的经过的路径点数随着 \(k\) 的增大而减小,于是可以考虑根号分治。

  • \(k \leq \sqrt{n}\) 时,记 \(sum_{i,j}\) 为从 \(i\) 节点往根节点走 \(k\) 的大小,所能走到的权值和。这个可以 \(O(n \sqrt{n})\) 预处理得出。每次询问的时候只需特判一下 \(\mathrm{LCA}\) 即可。

  • \(k > \sqrt{n}\) 时,只需要暴力往 \(k\) 级祖先跳即可。但是倍增求 \(k\) 级祖先的复杂度要带个 \(\log\),用到长链剖分,可以每次做到 \(O(1)\) 跳到 \(k\) 级祖先。

code:

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
bool mem1;
const int N=5e4+10,B=233;
int dep[N],maxd[N],b[N],c[N],h[N],idx,n,a[N],son[N],top[N],sum[N][B],fa[N][21],lg[N];
vector<int>up[N],down[N];
struct edge{int v,nex;}e[N<<1];
void add(int u,int v){e[++idx]=edge{v,h[u]};h[u]=idx;}
void dfs1(int u,int Fa)
{
	dep[u]=maxd[u]=dep[Fa]+1;fa[u][0]=Fa;
	for(int i=1;i<=20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
	for(int v,i=h[u];i;i=e[i].nex)
	{
		v=e[i].v;if(v==Fa) continue;dfs1(v,u);
		if(maxd[v]>maxd[u]) maxd[u]=maxd[v],son[u]=v;
	}
}
void dfs2(int u,int t)
{
	top[u]=t;for(int i=1,p=fa[u][0];i<B;i++) sum[u][i]=sum[p][i]+a[u],p=fa[p][0];
	if(u==t)
	{
		for(int i=0,v=u;i<=maxd[u]-dep[u];i++) up[u].push_back(v),v=fa[v][0];
		for(int i=0,v=u;i<=maxd[u]-dep[u];i++) down[u].push_back(v),v=son[v];
	}
	if(son[u]) dfs2(son[u],t);
	for(int i=h[u],v;i;i=e[i].nex)
	{
		v=e[i].v;if(v==fa[u][0]||v==son[u]) continue;dfs2(v,v);
	}
}
int LCA(int x,int y)
{
	if(dep[x]<dep[y]) swap(x,y);
	for(int i=20;i>=0;i--) if(dep[fa[x][i]]>=dep[y]) x=fa[x][i];if(x==y) return x;
	for(int i=20;i>=0;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];return fa[x][0];
}
int get_k(int u,int k)
{
	if(dep[u]<=k) return 0;if(!k) return u;
	u=fa[u][lg[k]];k-=1<<lg[k];k-=dep[u]-dep[top[u]];u=top[u];
	return k>=0?up[u][k]:down[u][-k];
}
bool mem2;
int main()
{
	scanf("%d",&n);lg[0]=-1;for(int i=1;i<=n;i++) scanf("%d",&a[i]),lg[i]=lg[i>>1]+1;
	for(int u,v,i=1;i<n;i++) scanf("%d%d",&u,&v),add(u,v),add(v,u);dfs1(1,0);dfs2(1,1);
	for(int i=1;i<=n;i++) scanf("%d",&b[i]);for(int i=1;i<n;i++) scanf("%d",&c[i]);
	for(int i=1;i<n;i++)
	{
		int x=b[i],y=b[i+1],k=c[i],z=LCA(x,y),res=0;
		if(k>=B)
		{
			for(int u=x;dep[u]>=dep[z];u=get_k(u,k)) res+=a[u];
			for(int u=y;dep[u]>=dep[z];u=get_k(u,k)) res+=a[u];
			if((dep[x]-dep[z])%k==0&&(dep[y]-dep[z])%k==0) res-=a[z];
		}
		else
		{
			res+=sum[x][k]-sum[get_k(x,k*((dep[x]-dep[z])/k+1))][k];
			res+=sum[y][k]-sum[get_k(y,k*((dep[y]-dep[z])/k+1))][k];
			if((dep[x]-dep[z])%k==0&&(dep[y]-dep[z])%k==0) res-=a[z];
		}
		printf("%d\n",res);
	}
	return 0;
}
posted @ 2023-03-17 14:43  曙诚  阅读(45)  评论(0)    收藏  举报