BZOJ 4012 【HNOI2015】 开店

题目链接:开店

  这道题显然一眼树分治,维护点分的结构,在每个点上,对每种年龄到这个点\(u\)以及他在点分树上父亲的距离和建一棵线段树,查询的时候一路往上跳即可。

  但是我太懒了(其实你要说我不会也可以),所以并不想写这种东西。于是,我就只能尝试一下别的方法。

  设一个点\(u\)的年龄为\(y_u\),\(u\)、\(v\)两点之间的距离为\(dis(u,v)\),\(T_u=dis(root,u)\),我们每次要求的式子是:

\begin{aligned} &\sum_{y_x\in [l,r]} dis(x,u)\\ =&\sum_{y_x\in [l,r]}(T_x+T_u-2T_{LCA(u,x)})\end{aligned}

  注意到前面那两项我们是可以通过预处理前缀和\(O(1)\)求出的。于是我们就只需要考虑后面那坨东西怎么求。

  我们可以考虑转化一下思路,转而求每条边的贡献。我们考虑对于一个点\(x\)满足\(y_x\in[l,r]\),那么\(LCA(u,x)\)一直到根的路径都要被计算一次。那么我们就可以对于每个\(y_x\in[l,r]\),把点\(x\)往上跳,途中经过的边标记加\(1\)。那么最后我们再从\(u\)往上跳,每条边的的标记数就是这条边被计算的次数。那么我们就只需要快速维护一个点到根的路径即可。这个可以树链剖分之后用权值线段树解决。

  下面贴代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define File(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)
#define maxn 150010
#define MAXN 6000010
#define point pair<int,int>
#define sa first
#define sb second
 
using namespace std;
typedef long long llg;
 
point s[maxn];
int n,Q,A,a[maxn],da[maxn],ld,dep[maxn],c1[maxn];//c1[x]表示[1,x]出现的次数和
int head[maxn],next[maxn<<1],to[maxn<<1],c[maxn<<1],tt;//邻接表
int fa[maxn],top[maxn],siz[maxn],son[maxn],tc[maxn],wd[maxn],fd[maxn];//树链剖分
int rt[maxn],addv[MAXN],le[MAXN],ri[MAXN],L,R,CO,_1,_2;//主席树
llg ans,c2[maxn],sumv[MAXN],ji;//c2[x]表示∑dep[u](a[u]∈[1,x])
 
int getint(){
	int w=0;bool q=0;
	char c=getchar();
	while((c>'9'||c<'0')&&c!='-') c=getchar();
	if(c=='-') c=getchar(),q=1;
	while(c>='0'&&c<='9') w=w*10+c-'0',c=getchar();
	return q?-w:w;
}
 
void link(int x,int y){
	to[++tt]=y;next[tt]=head[x];head[x]=tt;
	to[++tt]=x;next[tt]=head[y];head[y]=tt;
	c[tt]=c[tt-1]=getint();
}
 
void dfs(int u){
	siz[u]=1; c1[a[u]]++; c2[a[u]]+=dep[u];
	for(int i=head[u],v;v=to[i],i;i=next[i])
		if(!siz[v]){
			dep[v]=dep[u]+c[i];
			fa[v]=u; dfs(v); siz[u]+=siz[v];
			if(siz[v]>siz[son[u]]) son[u]=v;
		}
}
 
void dfs(int u,int ot){
	top[u]=ot; tc[u]=++tt;
	wd[tt]=dep[u]; fd[tt]=dep[fa[u]];
	if(son[u]) dfs(son[u],ot);
	for(int i=head[u],v;v=to[i],i;i=next[i])
		if(!top[v]) dfs(v,v);
}
 
void add(int &u,int l,int r){
	tt++; addv[tt]=addv[u];
	le[tt]=le[u],ri[tt]=ri[u];
	sumv[tt]=sumv[u]; u=tt;
	int mid=(l+r)>>1;
	if(l>=L && r<=R){
		sumv[u]+=wd[r]-fd[l];
		addv[u]++; return;
	}
	if(L<=mid) add(le[u],l,mid);
	if(R>mid) add(ri[u],mid+1,r);
	sumv[u]=sumv[le[u]]+sumv[ri[u]];
	sumv[u]+=(llg)addv[u]*(wd[r]-fd[l]);
}
 
void work(int u,int co){
	while(u){
		L=tc[top[u]],R=tc[u];
		add(rt[co],1,n); u=fa[top[u]];
	}
}
 
void query(int u1,int u2,int l,int r){
	int mid=(l+r)>>1;
	if(l>=L && r<=R){
		ji+=sumv[u2]+(llg)_2*(wd[r]-fd[l]);
		ji-=sumv[u1]+(llg)_1*(wd[r]-fd[l]);
		return;
	}
	_1+=addv[u1]; _2+=addv[u2];
	if(L<=mid) query(le[u1],le[u2],l,mid);
	if(R>mid) query(ri[u1],ri[u2],mid+1,r);
	_1-=addv[u1]; _2-=addv[u2];
}
 
int up(int x){//二分>=x的第一个
	int l=1,r=ld,mid;
	while(l!=r){
		mid=(l+r)>>1;
		if(da[mid]>=x) r=mid;
		else l=mid+1;
	}
	return l;
}
 
int lo(int x){//二分<=x的第一个
	int l=1,r=ld,mid;
	while(l!=r){
		mid=(l+r+1)>>1;
		if(da[mid]<=x) l=mid;
		else r=mid-1;
	}
	return l;
}
 
int main(){
	File("shop");
	n=getint(); Q=getint(); A=getint(); ld=n;
	for(int i=1;i<=n;i++) a[i]=da[i]=getint(); da[++ld]=A+1; da[++ld]=0;
	sort(da+1,da+ld+1); ld=unique(da+1,da+ld+1)-da-1;
	for(int i=1;i<=n;i++) a[i]=up(a[i]),s[i]=make_pair(a[i],i);
	for(int i=1;i<n;i++) link(getint(),getint());
	tt=0; dfs(1); dfs(1,1); sort(s+1,s+n+1); tt=0;
	for(int i=1;i<=ld;i++) c1[i]+=c1[i-1],c2[i]+=c2[i-1];
	for(int i=1;i<=n;i++){
		if(s[i].sa!=s[i-1].sa) rt[s[i].sa]=rt[s[i-1].sa];
		CO=s[i].sa; work(s[i].sb,s[i].sa);
	}
	rt[ld]=rt[ld-1];
	while(Q--){
		int u=getint(),aa=getint(),bb=getint(),l,r; ji=0;
		(aa+=ans%A)%=A; (bb+=ans%A)%=A;
		l=min(aa,bb),r=max(aa,bb); l=up(l); r=lo(r);
		if(l>r) ans=ji=0;
		else{
			ans=c2[r]-c2[l-1]+(llg)(c1[r]-c1[l-1])*dep[u];
			while(u){
				L=tc[top[u]],R=tc[u];
				query(rt[l-1],rt[r],1,n);
				u=fa[top[u]];
			}
			ans-=ji<<1;
		}
		printf("%lld\n",ans);
	}
	return 0;
}
posted @ 2017-01-22 09:36  lcf2000  阅读(268)  评论(2编辑  收藏  举报