【洛谷】P3591 [POI2015] ODW(根号分治+长链剖分)
题意
给定一棵 \(n\) 个点的树,树上每条边的长度都为 \(1\),第 \(i\) 个点的权值为 \(a_i\)。
Byteasar 想要走遍这整棵树,他会按照某个 \(1\) 到 \(n\) 的全排列 \(b\) 走 \(n-1\) 次,第 \(i\) 次他会从 \(b_i\) 点走到 \(b_{i + 1}\) 点,并且这一次的步伐大小为 \(c_i\)。
对于一次行走,假设起点为 \(x\),终点为 \(y\),步伐为 \(k\),那么Byteasar会从 \(x\) 开始,每步往前走 \(k\) 条边,数据保证了每次行走的距离是 \(k\) 的倍数。
请帮助 Byteasar 统计出每一次行走时经过的所有点的权值和。
\(n \leq 50000\)。
思路
不难发现,本题中的经过的路径点数随着 \(k\) 的增大而减小,于是可以考虑根号分治。
-
当 \(k \leq \sqrt{n}\) 时,记 \(sum_{i,j}\) 为从 \(i\) 节点往根节点走 \(k\) 的大小,所能走到的权值和。这个可以 \(O(n \sqrt{n})\) 预处理得出。每次询问的时候只需特判一下 \(\mathrm{LCA}\) 即可。
-
当 \(k > \sqrt{n}\) 时,只需要暴力往 \(k\) 级祖先跳即可。但是倍增求 \(k\) 级祖先的复杂度要带个 \(\log\),用到长链剖分,可以每次做到 \(O(1)\) 跳到 \(k\) 级祖先。
code:
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
bool mem1;
const int N=5e4+10,B=233;
int dep[N],maxd[N],b[N],c[N],h[N],idx,n,a[N],son[N],top[N],sum[N][B],fa[N][21],lg[N];
vector<int>up[N],down[N];
struct edge{int v,nex;}e[N<<1];
void add(int u,int v){e[++idx]=edge{v,h[u]};h[u]=idx;}
void dfs1(int u,int Fa)
{
dep[u]=maxd[u]=dep[Fa]+1;fa[u][0]=Fa;
for(int i=1;i<=20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
for(int v,i=h[u];i;i=e[i].nex)
{
v=e[i].v;if(v==Fa) continue;dfs1(v,u);
if(maxd[v]>maxd[u]) maxd[u]=maxd[v],son[u]=v;
}
}
void dfs2(int u,int t)
{
top[u]=t;for(int i=1,p=fa[u][0];i<B;i++) sum[u][i]=sum[p][i]+a[u],p=fa[p][0];
if(u==t)
{
for(int i=0,v=u;i<=maxd[u]-dep[u];i++) up[u].push_back(v),v=fa[v][0];
for(int i=0,v=u;i<=maxd[u]-dep[u];i++) down[u].push_back(v),v=son[v];
}
if(son[u]) dfs2(son[u],t);
for(int i=h[u],v;i;i=e[i].nex)
{
v=e[i].v;if(v==fa[u][0]||v==son[u]) continue;dfs2(v,v);
}
}
int LCA(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
for(int i=20;i>=0;i--) if(dep[fa[x][i]]>=dep[y]) x=fa[x][i];if(x==y) return x;
for(int i=20;i>=0;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];return fa[x][0];
}
int get_k(int u,int k)
{
if(dep[u]<=k) return 0;if(!k) return u;
u=fa[u][lg[k]];k-=1<<lg[k];k-=dep[u]-dep[top[u]];u=top[u];
return k>=0?up[u][k]:down[u][-k];
}
bool mem2;
int main()
{
scanf("%d",&n);lg[0]=-1;for(int i=1;i<=n;i++) scanf("%d",&a[i]),lg[i]=lg[i>>1]+1;
for(int u,v,i=1;i<n;i++) scanf("%d%d",&u,&v),add(u,v),add(v,u);dfs1(1,0);dfs2(1,1);
for(int i=1;i<=n;i++) scanf("%d",&b[i]);for(int i=1;i<n;i++) scanf("%d",&c[i]);
for(int i=1;i<n;i++)
{
int x=b[i],y=b[i+1],k=c[i],z=LCA(x,y),res=0;
if(k>=B)
{
for(int u=x;dep[u]>=dep[z];u=get_k(u,k)) res+=a[u];
for(int u=y;dep[u]>=dep[z];u=get_k(u,k)) res+=a[u];
if((dep[x]-dep[z])%k==0&&(dep[y]-dep[z])%k==0) res-=a[z];
}
else
{
res+=sum[x][k]-sum[get_k(x,k*((dep[x]-dep[z])/k+1))][k];
res+=sum[y][k]-sum[get_k(y,k*((dep[y]-dep[z])/k+1))][k];
if((dep[x]-dep[z])%k==0&&(dep[y]-dep[z])%k==0) res-=a[z];
}
printf("%d\n",res);
}
return 0;
}