BSOJ6327【10.17题目】道路road
题目
分析
神仙倍增题。
首先观察到题目有 \(a_i\le 1\) 的部分分,启示我们倍增维护当前点往上 \(2^i\) 走1的个数。
询问就是直接倍增就好了。
那么考虑没有这个限制,显然可以对每一个二进制位都这样维护一下,于是复杂度 \(O(n\log V\log n)\)
结果出题人没卡掉这个复杂度。
考虑继续优化,发现其实我们不需要对于每一个二进制位都这样记,而是可以直接记答案:
\(dpu[i][j]\) 表示 \(i\) 到 \(i\) 的 \(2^i\) 父亲的儿子,这条路径的答案。
\(dpd[i][j]\) 表示 \(i\) 的 \(2^i\) 父亲的儿子 到 \(i\),这条路径的答案。
然后考虑询问,我们可以一边倍增跳一边询问。
这样显然会出问题,因为我们没有算之前跳过的步数的对应位。
所以我们需要在倍增加的同时还要加上变化的量,具体可以见代码 \(Queryup\) 和 \(Querydown\) 函数。
然后注意对于下降的那一段我们进行了一个差分。
时间复杂度 \(O(n\log n)\)
代码
数组的i,j反了是因为我在卡最优解((
#pragma GCC optimize(2)
#include<bits/stdc++.h>
using namespace std;
const int V=2e7+5;
char rbuf[V],obuf[V];
int pt=-1,pt1=-1;
#define getchar() rbuf[++pt]
#define putchar(x) obuf[++pt1]=x
template<typename T>
inline void read(T &x){
x=0;bool f=false;char ch=getchar();
while(!isdigit(ch)){f|=ch=='-';ch=getchar();}
while(isdigit(ch)){x=x*10+(ch^48);ch=getchar();}
x=f?-x:x;
return;
}
template<typename T>
void write(T x){
if(x<0) x=-x,putchar('-');
if(x>9) write(x/10);
putchar(x%10^48);
return ;
}
#define ll long long
const int N=3e5+5,M=2e5+5,MOD=1e9+7,INF=1e9+7;
ll Ans;
int n,m,a[N];
int head[N],to[N<<1],nex[N<<1],idx;
inline void add(int u,int v){
nex[++idx]=head[u];
to[idx]=v;
head[u]=idx;
return ;
}
int dep[N],siz[N],son[N],top[N];
int fa[21][N],cnt[21][N];
ll dpu[21][N],dpd[21][N];
void dfs1(int x,int f){
dep[x]=dep[f]+1,siz[x]=1,fa[0][x]=f;
for(int i=0;i<=20;i++) cnt[i][x]=cnt[i][f]+(!(a[x]&(1<<i)));//统计i到根节点的前缀:第i位是0的数的个数
for(int i=head[x];i;i=nex[i]){
int y=to[i];
if(y==f) continue;
dfs1(y,x);
siz[x]+=siz[y];
if(siz[y]>siz[son[x]]) son[x]=y;
}
return ;
}
void dfs2(int x){
if(x==son[fa[0][x]]) top[x]=top[fa[0][x]];
else top[x]=x;
if(son[x]) dfs2(son[x]);
for(int i=head[x];i;i=nex[i]){
int y=to[i];
if(y==fa[0][x]||y==son[x]) continue;
dfs2(y);
}
return ;
}
inline int QueryLca(int u,int v){
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
u=fa[0][top[u]];
}
return dep[u]<dep[v]?u:v;
}
inline ll Queryup(int x,int y){
int len=dep[x]-dep[y],now=x;
ll res=0;
for(int i=20;i>=0;i--){
if(len&(1<<i)){
res+=dpu[i][now];
now=fa[i][now];
res+=1ll*(1<<i)*(cnt[i][now]-cnt[i][y]);
}
}
// cout<<res<<"!"<<endl;
return res;
}
inline ll Querydown(int x,int len){//为了方便才用len不是y
ll res=0;int now=x;
for(int i=0;i<=20;i++){
if(len&(1<<i)){
res+=dpd[i][now];
res+=1ll*(1<<i)*(cnt[i][x]-cnt[i][now]);
now=fa[i][now];
}
}
return res;
}
signed main(){
// system("fc tree.out ex_tree1.out");
// freopen("tree.in","r",stdin);
// freopen("tree.out","w",stdout);
fread(rbuf,1,V-5,stdin);
read(n),read(m);
for(int i=1;i<=n;i++) read(a[i]),dpu[0][i]=dpd[0][i]=a[i];
for(int i=1,u,v;i<n;i++) read(u),read(v),add(u,v),add(v,u);
dfs1(1,0),dfs2(1);
for(int j=1;j<=20;j++){//为什么是20,因为最多或上第19位
for(int i=1;i<=n;i++){
fa[j][i]=fa[j-1][fa[j-1][i]];//fa数组
int x=fa[j-1][i],y=fa[j][i];
dpu[j][i]=dpu[j-1][i]+dpu[j-1][x]+1ll*(1<<(j-1))*(cnt[j-1][x]-cnt[j-1][y]);//上升段,表示i到i的2^j的祖先的儿子的答案
dpd[j][i]=dpd[j-1][i]+dpd[j-1][x]+1ll*(1<<(j-1))*(cnt[j-1][i]-cnt[j-1][x]);//下降段,表示i到i的2^j的祖先的儿子的答案
}
}
while(m--){
int x,y;
read(x),read(y);
int lca=QueryLca(x,y),len=dep[x]+dep[y]-dep[lca]-dep[lca];
Ans=Queryup(x,fa[0][lca])+Querydown(y,len+1)-Querydown(lca,len+1-(dep[y]-dep[lca]));//注意这里下降段的差分
write(Ans),putchar('\n');
}
fwrite(obuf,1,pt1+1,stdout);
return 0;
}
/*
5 2
4 3 2 5 3
1 2
1 3
3 4
3 5
2 5
3 4
*/