loj#2050. 「HNOI2016」树
小 A 想做一棵很大的树,但是他手上的材料有限,只好用点小技巧了。
开始,小 A 只有一棵结点数为 \(n\) 的树,结点的编号为 \(1,2,\cdots ,n\) ,其中结点 \(1\) 为根;我们称这颗树为模板树。小 A 决定通过这棵模板树来构建一颗大树。构建过程如下:
- 将模板树复制为初始的大树。
- 以下 (2.1) (2.2) (2.3) 步循环执行 \(m\) 次。
2.1. 选择两个数字 \(a,b\),其中 \(1\leq a\leq n,1\leq b\) 当前大树的结点数。
2.2. 将模板树中以结点 \(a\) 为根的子树复制一遍,挂到大树中结点 \(b\) 的下方 (也就是说,模板树中的结点 \(a\) 为根的子树复制到大树中后,将成为大树中结点 \(b\) 的子树)。
2.3. 将新加入大树的结点按照在模板树中编号的顺序重新编号。例如,假设在进行 (2.2) 步之前大树有 \(l\) 个结点,模板树中以 \(a\) 为根的子树共有 \(c\) 个结点,那么新加入模板树的 \(c\) 个结点在大树中的编号将是 \(l+1,l+2,\cdots,l+c\) ;大树中这 \(c\) 个结点编号的大小顺序和模板树中对应的 \(c\) 个结点的大小顺序是一致的。\(1\leq n,m,q\leq 10^5\)
因为复制的只有 \(n\) 个点的子树,所以可以考虑每个子树缩成一个点 . 然后赋值上新的边权 .
但是有一个问题,就是怎么计算现在的编号在原来节点上编号的问题 . 我一开始采取的是离线下来,通过线段树合并来做 .
虽然这简单的 \(3\) 行,但是我写了 \(300+\) 行,\(9k\) ,我太 sb 了 .
时间复杂度 : \(O(n\log n)\)
空间复杂度 : \(O(n\log n)\)
code
#include<bits/stdc++.h>
using namespace std;
char in[100005];
int iiter=0,llen=0;
inline char get(){
if(iiter==llen)llen=fread(in,1,100000,stdin),iiter=0;
if(llen==0)return EOF;
return in[iiter++];
}
inline long long rd(){
char ch=get();while(ch<'0'||ch>'9')ch=get();
long long res=0;while(ch>='0'&&ch<='9')res=(res<<3)+(res<<1)+ch-'0',ch=get();
return res;
}
inline void pr(long long res){
if(res==0){putchar('0');return;}
static int out[20];int len=0;
while(res)out[len++]=res%10,res/=10;
for(int i=len-1;i>=0;i--)putchar(out[i]+'0');
}
//首先找到所有需要知道在原本的树上的节点编号,用在原本的树上线段树合并
//其次对原本的树树剖,然后预处理准备差分
//计算出现在的树上连接点的位置,计算边的权值,树剖,预处理准备差分
#define lint long long
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define fi first
#define se second
const int N=1e5+10;
int n,m,q;
vector<pii>E;
vector<pair<long long,int> >E2;
vector<pair<lint,pii> >vnd;
//左端点,原图上的点,现图上的点
map<long long,pii>Map;
//节点编号,原图上的点,现图上的点
vector<pair<long long,pair<int,int> > >Q[N];
//原图上的点,现图上的点
vector<pair<long long,long long> >qq;
class Tree1{
public:
vector<int>g[N];
int sz[N];
inline void ae(int u,int v){
g[u].pb(v);
g[v].pb(u);
}
void get_sz(int x,int fa){
sz[x]=1;
for(auto to:g[x]){
if(to==fa)continue;
get_sz(to,x);
sz[x]+=sz[to];
}
}
void get_E(int x,int fa){
for(auto to:g[x]){
if(to==fa)continue;
get_E(to,x);
E.pb(mp(x,to));
}
}
class node{public:int ls,rs,cnt,id;}ts[N*30];
int cnt=0,rt[N];
inline int new_node(){
ts[++cnt]=(node){0,0,0,-1};
return cnt;
}
void upd(int x,int l,int r,int pos){
if(l==r){
ts[x].cnt++;
ts[x].id=pos;
return;
}
int mid=(l+r)>>1;
if(pos<=mid){
if(!ts[x].ls)ts[x].ls=new_node();
upd(ts[x].ls,l,mid,pos);
}else{
if(!ts[x].rs)ts[x].rs=new_node();
upd(ts[x].rs,mid+1,r,pos);
}
ts[x].cnt=0;
if(ts[x].ls)ts[x].cnt+=ts[ts[x].ls].cnt;
if(ts[x].rs)ts[x].cnt+=ts[ts[x].rs].cnt;
}
int merge(int a,int b,int l,int r){
if(!a||!b)return a+b;
if(l==r){
ts[a].cnt+=ts[b].cnt;
ts[a].id=l;
return a;
}
int mid=(l+r)>>1;
ts[a].ls=merge(ts[a].ls,ts[b].ls,l,mid);
ts[a].rs=merge(ts[a].rs,ts[b].rs,mid+1,r);
ts[a].cnt=0;
if(ts[a].ls)ts[a].cnt+=ts[ts[a].ls].cnt;
if(ts[a].rs)ts[a].cnt+=ts[ts[a].rs].cnt;
return a;
}
int qry(int x,int l,int r,int k){
if(l==r)return ts[x].id;
int mid=(l+r)>>1;
if(ts[x].ls&&ts[ts[x].ls].cnt>=k)return qry(ts[x].ls,l,mid,k);
if(ts[x].ls)k-=ts[ts[x].ls].cnt;
return qry(ts[x].rs,mid+1,r,k);
}
void transfer(int x,int l,int r){
if(l==r){
cerr<<ts[x].id<<" ";
return;
}
int mid=(l+r)>>1;
if(ts[x].ls)transfer(ts[x].ls,l,mid);
if(ts[x].rs)transfer(ts[x].rs,mid+1,r);
}
void get_id(int x,int fa){
rt[x]=new_node();
upd(rt[x],0,n-1,x);
for(int i=0;i<(int)g[x].size();i++){
int to=g[x][i];
if(to==fa)continue;
get_id(to,x);
merge(rt[x],rt[to],0,n-1);
}
for(int i=0;i<(int)Q[x].size();i++){
// Q 里面第一个是原图上的点,第二个是在此子树中的大小关系
long long tmp=Q[x][i].fi;int id=Q[x][i].se.se,k=Q[x][i].se.fi;
// if(id==633)cerr<<x<<","<<id<<","<<k<<endl;
Map[tmp]=mp(qry(rt[x],0,n-1,k+1),id);
}
}
int fa[N],hd[N],hv[N],pos[N],dep[N];
void get_hv(int x,int f){
fa[x]=f;hv[x]=-1;
for(auto to:g[x]){
if(to==f)continue;
dep[to]=dep[x]+1;
get_hv(to,x);
if(hv[x]==-1||sz[hv[x]]<sz[to]){
hv[x]=to;
}
}
}
void depo(int x,int fa){
pos[x]=cnt++;
if(hv[x]!=-1){
hd[hv[x]]=hd[x];
depo(hv[x],x);
}
for(auto to:g[x]){
if(to==fa||to==hv[x])continue;
hd[to]=to;
depo(to,x);
}
}
void tree_depo(){
cnt=0;
get_hv(0,-1);
depo(0,-1);
}
int lca(int u,int v){
while(hd[u]!=hd[v]){
if(dep[hd[u]]>dep[hd[v]])swap(u,v);
v=fa[hd[v]];
}
return dep[u]<dep[v]?u:v;
}
int get_dis(int u,int v){
int r=lca(u,v);
return dep[u]+dep[v]-2*dep[r];
}
}T1;
class Tree2{
public:
//双倍节点
class edge{public:int to,w,id;}; // 新的节点,连接点原图编号,代价
vector<edge>g[N<<1];
int fid[N<<1],id[N<<1];
inline void ae(int u,int v,int id){
g[u].pb((edge){v,0,id});
}
void pre_dis(int x){
for(auto&e:g[x]){
e.w=T1.dep[e.id]-T1.dep[id[x]]+1;
fid[e.to]=e.id;
pre_dis(e.to);
}
}
int fa[N<<1],hd[N<<1],sz[N<<1],hv[N<<1],dep[N<<1],pos[N<<1],rpos[N<<1],cnt=0;
long long dis[N<<1];
void get_hv(int x){
hv[x]=-1;sz[x]=1;
for(auto e:g[x]){
dep[e.to]=dep[x]+1;
dis[e.to]=dis[x]+e.w;
assert(dis[e.to]>=0);
fa[e.to]=x;
get_hv(e.to);
sz[x]+=sz[e.to];
if(hv[x]==-1||sz[e.to]>sz[hv[x]])hv[x]=e.to;
}
}
void depo(int x){
pos[x]=cnt;rpos[cnt]=x;cnt++;
if(hv[x]!=-1){
hd[hv[x]]=hd[x];
depo(hv[x]);
}
for(auto e:g[x]){
if(e.to==hv[x])continue;
hd[e.to]=e.to;
depo(e.to);
}
}
void tree_depo(){
fa[0]=-1;
get_hv(0);
depo(0);
}
int lca(int u,int v){
while(hd[u]!=hd[v]){
if(dep[hd[u]]>dep[hd[v]])swap(u,v);
v=fa[hd[v]];
}
return dep[u]>dep[v]?v:u;
}
int get_fid(int x,int t){
int lst=-1;
while(hd[x]!=hd[t]){
lst=fid[hd[x]];
x=fa[hd[x]];
}
if(x!=t)lst=fid[rpos[pos[t]+1]];
return lst;
}
}T2;
long long get_ans(int id1,int k1,int id2,int k2,int t){
int r=T2.lca(id1,id2);
long long res=0;
if(r==id2)swap(id1,id2),swap(k1,k2);
if(r==id1){
int fid1=k1,fid2=T2.get_fid(id2,r);
res+=T1.dep[k2]-T1.dep[T2.id[id2]];
res+=T2.dis[id2]-T2.dis[r];
res-=T1.dep[fid2]-T1.dep[T2.id[r]];
res+=T1.get_dis(fid1,fid2);
return res;
}
res+=T1.dep[k1]-T1.dep[T2.id[id1]];
res+=T1.dep[k2]-T1.dep[T2.id[id2]];
int fid1=T2.get_fid(id1,r),fid2=T2.get_fid(id2,r);
res+=T2.dis[id1]-T2.dis[r];
res-=T1.dep[fid1]-T1.dep[T2.id[r]];
res+=T2.dis[id2]-T2.dis[r];
res-=T1.dep[fid2]-T1.dep[T2.id[r]];
res+=T1.get_dis(fid1,fid2);
if(t==101)cerr<<res<<endl;
return res;
}
inline int get_id1(long long val){ // 找到原图上的子树
int id=upper_bound(vnd.begin(),vnd.end(),mp(val+1,mp(-1,-1)))-vnd.begin()-1;
return vnd[id].se.fi;
}
inline int get_id2(long long val){ // 找到现图上的点
int id=upper_bound(vnd.begin(),vnd.end(),mp(val+1,mp(-1,-1)))-vnd.begin()-1;
return vnd[id].se.se;
}
inline int get_id3(long long val){ // 找到原图上的大小关系
int id=upper_bound(vnd.begin(),vnd.end(),mp(val+1ll,mp(-1,-1)))-vnd.begin()-1;
return val-vnd[id].fi;
}
int main(){
n=rd();m=rd();q=rd();
for(int i=0;i<n-1;i++){
int u=rd()-1,v=rd()-1;
T1.ae(u,v);
}
T1.get_sz(0,-1);T1.get_E(0,-1);
for(int i=0;i<n-1;i++){
int u=E[i].fi,v=E[i].se;
T2.ae(u,v,u);
}
for(int i=0;i<n;i++)T2.id[i]=i;
long long tmp=n;
for(int i=0;i<n;i++)vnd.pb(mp(i,mp(i,i))),Map[i]=mp(i,i);
for(int i=0;i<m;i++){
int x=rd()-1;long long to=rd()-1;
T2.id[i+n]=x;E2.pb(mp(to,i+n));
if(to>=n)Q[get_id1(to)].pb(mp(to,mp(get_id3(to),get_id2(to))));
vnd.pb(mp(tmp,mp(x,n+i)));
tmp+=T1.sz[x];
}
for(int i=0;i<q;i++){
long long u=rd()-1,v=rd()-1;
if(u>=n)Q[get_id1(u)].pb(mp(u,mp(get_id3(u),get_id2(u))));
if(v>=n)Q[get_id1(v)].pb(mp(v,mp(get_id3(v),get_id2(v))));
qq.pb(mp(u,v));
}
T1.get_id(0,-1);
for(int i=0;i<(int)E2.size();i++){
long long u=E2[i].fi;int v=E2[i].se;
T2.ae(Map[u].se,v,Map[u].fi);
}
T1.tree_depo();
T2.pre_dis(0);
T2.tree_depo();
T2.fid[0]=-1;
for(int i=0;i<q;i++){
long long u=qq[i].fi,v=qq[i].se;
int id1=Map[u].fi,id2=Map[v].fi,k1=Map[u].se,k2=Map[v].se;
if(k1==k2)pr(T1.get_dis(id1,id2));
else pr(get_ans(k1,id1,k2,id2,i));
putchar('\n');
}
return 0;
}
观察他人的程序后发现,这个子树中编号的问题可以变成 \(dfs\) 序上区间第 \(k\) 大问题,可用可持久化线段树来在线解决 .
然后其他的好像也没有什么区别,但是我为什么写的这么多呢?sbsb 了.
code
#pragma GCC optimize(3)
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
template<typename T> void read(T &x) {
x = 0; bool f = false; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) f ^= (ch == '-');
for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
if (f) x = -x;
}
template<typename T> void write(T x) {
if (x < 0) x = -x, putchar('-');
if (x > 9) write(x / 10);
putchar((x % 10) ^ 48);
}
const int maxn = 1e5 + 10;
int n, cnt;
int sz[maxn], dfn[maxn], mp[maxn], dep1[maxn], fa1[maxn][18];
vector<int> g[maxn];
void dfs1(int u, int pre) {
sz[u] = 1, dfn[u] = ++cnt, mp[cnt] = u;
dep1[u] = dep1[pre] + 1, fa1[u][0] = pre;
for (int i = 1; i < 18; ++i) fa1[u][i] = fa1[fa1[u][i - 1]][i - 1];
for (auto v : g[u]) if (v != pre) dfs1(v, u), sz[u] += sz[v];
}
int lca1(int u, int v) {
if (dep1[u] < dep1[v]) swap(u, v);
for (int i = 17; i >= 0; --i) if (dep1[fa1[u][i]] >= dep1[v]) u = fa1[u][i];
if (u == v) return u;
for (int i = 17; i >= 0; --i) if (fa1[u][i] != fa1[v][i]) u = fa1[u][i], v = fa1[v][i];
return fa1[u][0];
}
int tot;
int rt[maxn], ls[maxn * 18], rs[maxn * 18], sum[maxn * 18];
void update(int &x, int y, int l, int r, int q) {
x = ++tot, ls[x] = ls[y], rs[x] = rs[y], sum[x] = sum[y] + 1;
if (l == r) return;
int mid = (l + r) >> 1;
if (q <= mid) {
update(ls[x], ls[y], l, mid, q);
} else {
update(rs[x], rs[y], mid + 1, r, q);
}
}
int query(int x, int y, int l, int r, int k) {
if (l == r) return l;
int mid = (l + r) >> 1, tmp = sum[ls[x]] - sum[ls[y]];
if (k <= tmp) {
return query(ls[x], ls[y], l, mid, k);
} else {
return query(rs[x], rs[y], mid + 1, r, k - tmp);
}
}
int m, q;
int nd[maxn], lk[maxn], dep2[maxn], fa2[maxn][18];
ll mx[maxn], dis[maxn];
int lca2(int u, int v) {
if (dep2[u] < dep2[v]) swap(u, v);
for (int i = 17; i >= 0; --i) if (dep2[fa2[u][i]] >= dep2[v]) u = fa2[u][i];
if (u == v) return u;
for (int i = 17; i >= 0; --i) if (fa2[u][i] != fa2[v][i]) u = fa2[u][i], v = fa2[v][i];
return fa2[u][0];
}
signed main() {
read(n), read(m), read(q);
for (int i = 1, u, v; i < n; ++i) {
read(u), read(v);
g[u].emplace_back(v);
g[v].emplace_back(u);
}
dfs1(1, 0);
for (int i = 1; i <= n; ++i) update(rt[i], rt[i - 1], 1, n, mp[i]);
mx[1] = n, nd[1] = dep2[1] = 1;
ll x;
for (int i = 2; i <= m + 1; ++i) {
read(nd[i]), read(x);
int t = lower_bound(mx + 1, mx + i, x) - mx;
lk[i] = query(rt[dfn[nd[t]] + sz[nd[t]] - 1], rt[dfn[nd[t]] - 1], 1, n, x - mx[t - 1]);
dep2[i] = dep2[t] + 1, fa2[i][0] = t, dis[i] = dis[t] + dep1[lk[i]] - dep1[nd[t]] + 1;
for (int j = 1; j < 18; ++j) fa2[i][j] = fa2[fa2[i][j - 1]][j - 1];
mx[i] = mx[i - 1] + sz[nd[i]];
}
while (q--) {
ll u, v;
int tu, tv;
read(u), read(v);
tu = lower_bound(mx + 1, mx + m + 2, u) - mx;
tv = lower_bound(mx + 1, mx + m + 2, v) - mx;
u = query(rt[dfn[nd[tu]] + sz[nd[tu]] - 1], rt[dfn[nd[tu]] - 1], 1, n, u - mx[tu - 1]);
v = query(rt[dfn[nd[tv]] + sz[nd[tv]] - 1], rt[dfn[nd[tv]] - 1], 1, n, v - mx[tv - 1]);
if (tu == tv) {
write(dep1[u] + dep1[v] - dep1[lca1(u, v)] * 2), putchar('\n');
continue;
}
if (dep2[tu] > dep2[tv]) swap(tu, tv), swap(u, v);
int t = lca2(tu, tv);
ll ans = dep1[u] - dep1[nd[tu]] + dep1[v] - dep1[nd[tv]] + dis[tu] + dis[tv] - dis[t] * 2;
if (t == tu) {
for (int i = 17; i >= 0; --i) if (dep2[fa2[tv][i]] > dep2[tu]) tv = fa2[tv][i];
ans -= (dep1[lca1(u, lk[tv])] - dep1[nd[tu]]) * 2;
} else {
for (int i = 17; i >= 0; --i) if (dep2[fa2[tu][i]] > dep2[t]) tu = fa2[tu][i];
for (int i = 17; i >= 0; --i) if (dep2[fa2[tv][i]] > dep2[t]) tv = fa2[tv][i];
ans -= (dep1[lca1(lk[tu], lk[tv])] - dep1[nd[t]]) * 2;
}
write(ans), putchar('\n');
}
return 0;
}

浙公网安备 33010602011771号