BZOJ1906树上的蚂蚁&BZOJ3700发展城市——RMQ求LCA+树链的交
题目描述
众所周知,Hzwer学长是一名高富帅,他打算投入巨资发展一些小城市。
Hzwer打算在城市中开N个宾馆,由于Hzwer非常壕,所以宾馆必须建在空中,但是这样就必须建立宾馆之间的连接通道。机智的Hzwer在宾馆中修建了N-1条隧道,也就是说,宾馆和隧道形成了一个树形结构。
Hzwer有时候会花一天时间去视察某个城市,当来到一个城市之后,Hzwer会分析这些宾馆的顾客情况。对于每个顾客,Hzwer用三个数值描述他:(S, T, V)表示该顾客这天想要从宾馆S走到宾馆T,他的速度是V。
Hzwer需要做一些收集一些数据,这样他就可以规划他接下来的投资。
其中有一项数据就是收集所有顾客可能的碰面次数。
每天清晨,顾客同时从S出发以V的速度前往T(注意S可能等于T),当到达了宾馆T的时候,顾客显然要找个房间住下,那么别的顾客再经过这里就不会碰面了。特别的,两个顾客同时到达一个宾馆是可以碰面的。同样,两个顾客同时从某宾馆出发也会碰面。
输入
第一行一个正整数T(1<=T<=20),表示Hzwer发展了T个城市,并且在这T个城市分别视察一次。
对于每个T,第一行有一个正整数N(1<=N<=10^5)表示Hzwer在这个城市开了N个宾馆。
接下来N-1行,每行三个整数X,Y,Z表示宾馆X和宾馆Y之间有一条长度为Z的隧道
再接下来一行M表示这天顾客的数量。
紧跟着M行每行三个整数(S, T, V)表示该顾客会从宾馆S走到宾馆T,速度为v
输出
对于每个T,输出一行,表示顾客的碰面次数。
样例输入
3
1 2 1
2 3 1
3
1 3 2
3 1 1
1 2 3
样例输出
0
提示
【数据规模】
1<=T<=20 1<=N<=10^5 0<=M<=10^3 1<=V<=10^6 1<=Z<=10^3
这题细节好多啊,蒟蒻的我调了一下午。
考虑到m的范围比较小,因此可以两两枚举判断是否相遇。
对于两个路径,如果能够相遇,相遇点一定在两个路径的交路径上。
如何求树上路径交?
对于两个路径A(a.u,a.v)与B(b.u,b.v)求出lca(a.u,b.u),lca(a.v,b.v),lca(a.v,b.u),lca(a.u,b.v)
去掉这四个点中不在A或B路径上的点,再去重后按dfs序排序,取后两个(如果只有一个说明路径只交于一点)就是交路径的两个端点
判断出两个路径起点先到达的交路径的端点是否是同一个,如果是就说明两个顾客是同向运动,反之则是相向运动。
如果两顾客是同向运动:只要先进入交路径的顾客后走出交路径就一定相遇。
如果两顾客是相向运动:分别求出两顾客进入和走出交路径的时间,判断只要两时间段有交集就能相遇,因为除法较慢,所以转成交叉相乘判断。
在判断和求路径过程中多次求lca,用O(logn)的方法求显然会TLE,在这里采用RMQ求lca:
在dfs时求出欧拉遍历序(就是遍历到一个点存一次)及每个点第一次被遍历的位置
对于x,y两点的lca就是欧拉序上两点第一次被遍历位置之间深度最小的点,用ST表即可O(1)查询
这道题有点卡常,注意涉及到乘速度时可能会爆longlong。
#include<cmath> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> #define ll long long using namespace std; inline char _read() { static char buf[100000],*p1=buf,*p2=buf; return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++; } inline int read() { int x=0,f=1;char ch=_read(); while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=_read();} while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=_read();} return x*f; } int T,n,m; int head[100010]; int s[100010]; int to[200010]; int next[200010]; int val[200010]; int d[100010]; int dep[100010]; int f[200010][18]; int g[200010][18]; int tot; int num; int x,y,z; int ans; int p[5]; int cnt; int b[200010]; struct miku { int u,v,w; }a[1010]; inline void add(int x,int y,int z) { tot++; next[tot]=head[x]; head[x]=tot; to[tot]=y; val[tot]=z; } inline void dfs(int x,int fa) { d[x]=d[fa]+1; s[x]=++num; f[num][0]=d[x]; g[num][0]=x; for(int i=head[x];i;i=next[i]) { if(to[i]!=fa) { dep[to[i]]=dep[x]+val[i]; dfs(to[i],x); f[++num][0]=d[x]; g[num][0]=x; } } } inline void ST() { for(int j=1;j<=17;j++) { for(int i=1;i<=num;i++) { if(i+(1<<j)-1>num) { break; } if(f[i][j-1]<f[i+(1<<(j-1))][j-1]) { f[i][j]=f[i][j-1]; g[i][j]=g[i][j-1]; } else { f[i][j]=f[i+(1<<(j-1))][j-1]; g[i][j]=g[i+(1<<(j-1))][j-1]; } } } } inline int lca(int x,int y) { x=s[x]; y=s[y]; if(x>y) { swap(x,y); } int len=b[y-x+1]; if(f[x][len]<f[y-(1<<len)+1][len]) { return g[x][len]; } else { return g[y-(1<<len)+1][len]; } } inline bool find(int anc,int x,int y) { int fx=lca(a[x].u,a[x].v); int fy=lca(a[y].u,a[y].v); if(lca(fx,anc)!=fx||lca(fy,anc)!=fy) { return false; } if(fx!=lca(fx,a[x].u)&&fx!=lca(fx,a[x].v)) { return false; } if(fy!=lca(fy,a[y].u)&&fy!=lca(fy,a[y].v)) { return false; } return true; } inline int dis(int x,int y) { int anc=lca(x,y); return dep[x]+dep[y]-2*dep[anc]; } inline bool cmp(int x,int y) { return s[x]<s[y]; } inline bool cpr(ll a,ll b,ll c) { if(a<=b&&b<=c) { return 1; } else { return 0; } } inline int check(int x,int y) { if(a[x].u==a[y].u) { return 1; } int res; cnt=0; res=lca(a[x].u,a[y].u); if(find(res,x,y)){p[++cnt]=res;} res=lca(a[x].v,a[y].v); if(find(res,x,y)){p[++cnt]=res;} res=lca(a[x].u,a[y].v); if(find(res,x,y)){p[++cnt]=res;} res=lca(a[y].u,a[x].v); if(find(res,x,y)){p[++cnt]=res;} if(cnt==0) { return 0; } sort(p+1,p+1+cnt,cmp); cnt=unique(p+1,p+1+cnt)-p-1; if(cnt==1) { if(1ll*dis(a[x].u,p[1])*a[y].w==1ll*dis(a[y].u,p[1])*a[x].w) { return 1; } else { return false; } } int st=p[cnt]; int ed=p[cnt-1]; int A1,A2,B1,B2; ll a1,a2,b1,b2; if(dis(a[x].u,st)<dis(a[x].u,ed)) { A1=st; A2=ed; } else { A1=ed; A2=st; } if(dis(a[y].u,st)<dis(a[y].u,ed)) { B1=st; B2=ed; } else { B1=ed; B2=st; } a1=1ll*dis(a[x].u,A1)*a[y].w; a2=1ll*dis(a[x].u,A2)*a[y].w; b1=1ll*dis(a[y].u,B1)*a[x].w; b2=1ll*dis(a[y].u,B2)*a[x].w; if(A1==B1) { if(a1==b1) { return 1; } if(a1<b1) { return b2<=a2; } else { return a2<=b2; } } else { if(cpr(a1,b1,a2))return 1; if(cpr(a1,b2,a2))return 1; if(cpr(b1,a1,b2))return 1; if(cpr(b1,a2,b2))return 1; return 0; } } int main() { T=read(); b[0]=-1; for(int i=1;i<=200010;i++) { b[i]=b[i>>1]+1; } while(T--) { memset(head,0,sizeof(head)); num=0; tot=0; ans=0; n=read(); for(int i=1;i<n;i++) { x=read(); y=read(); z=read(); add(x,y,z); add(y,x,z); } dfs(1,0); ST(); m=read(); for(int i=1;i<=m;i++) { a[i].u=read(); a[i].v=read(); a[i].w=read(); } for(int i=1;i<=m;i++) { for(int j=i+1;j<=m;j++) { ans+=check(i,j); } } printf("%d\n",ans); } }