树链剖分
用途:一棵树上求一条链上的点的权值和(或最大最小值),并且更改一个点(或边)的权值。
基本思路:破树成链(这个词是我乱造的)
实现方法:轻重边思想
名词解释:
名词 | 解释 |
---|---|
重边 | 印点出来的(每个非叶子点都有且仅有一条重边) |
轻边 | 除了重边以外的所有边 |
重链 | 一些重边连成一条链,这条链就是重链(单独一个点也算重链) |
重儿子 | 一个节点的重边连向的儿子 |
轻儿子 | 一个节点除了重儿子以外的所有儿子 |
以 |
|
节点 |
|
节点 |
|
节点 |
|
节点 |
|
节点 |
画个图:
另外,在实际应用中,通常使用这两条性质:
1. 每个点的儿子中
2. 在dfs时,先遍历重儿子,再遍历轻儿子。
其中第一条是时间复杂度保证(不会证明。。。),第二条是最关键思想的保证。
那么有了这些有什么用?
方法:把一棵树的点“搬”到一棵线段树上去。
怎么搬?
按照dfs序来搬。
现在你应该明白轻重边的意义了吧。
如果没有,那么让我们再看看图吧。
在这个图中,1-2、2-4、3-6都是重边,那么1-6的dfs序分别为:1,2,5,3,4,6;
那么线段树存放的顺序是:1,2,4,5,3,6。
这棵树的重链分别为:1-4、5-5和3-6。
如果看到这里还没明白,请再学习小学数学找规律题。
这对减少时间复杂度有什么帮助呢?
回到那张图,看一看4-6的路径以及这条路上的重链的个数。
如果在这棵树上dfs,需要走的次数是4;而每次走一条重链和一条轻边,那么只需要走两次。
是不是发现了什么?
是不是和倍增求LCA很像?(如果你不懂倍增求LCA,那么请百度一下)
是的,是和倍增求LCA相似,只是倍增向根跳的距离为
其实树链剖分也可以求LCA。
那么好了,应该能够理解了吧。
代码
//[ZJOI2008]数的统计count
#include <cstdio>
#include <algorithm>
const int maxn=30000;
const int inf=1000000000;
int n;
struct sigment_tree
//线段树模板
{
private:int maxnum[(maxn<<2)+10],sumnum[(maxn<<2)+10];
private:int updata(int now)
{
maxnum[now]=std::max(maxnum[now<<1],maxnum[now<<1|1]);
sumnum[now]=sumnum[now<<1]+sumnum[now<<1|1];
return 0;
}
public:int build(int now,int left,int right)
{
if(left==right)
{
maxnum[now]=0;
sumnum[now]=0;
return 0;
}
int mid=(left+right)>>1;
build(now<<1,left,mid);
build(now<<1|1,mid+1,right);
updata(now);
return 0;
}
public:int modify(int now,int left,int right,int findnum,int changeval)
{
if((findnum<left)||(findnum>right))
{
return 0;
}
if(left==right)
{
maxnum[now]=changeval;
sumnum[now]=changeval;
return 0;
}
int mid=(left+right)>>1;
modify(now<<1,left,mid,findnum,changeval);
modify(now<<1|1,mid+1,right,findnum,changeval);
updata(now);
return 0;
}
public:int askmax(int now,int left,int right,int askl,int askr)
{
if((askr<left)||(askl>right))
{
return -inf;
}
if((askl<=left)&&(askr>=right))
{
return maxnum[now];
}
int mid=(left+right)>>1;
return std::max(askmax(now<<1,left,mid,askl,askr),askmax(now<<1|1,mid+1,right,askl,askr));
}
public:int asksum(int now,int left,int right,int askl,int askr)
{
if((askr<left)||(askl>right))
{
return 0;
}
if((askl<=left)&&(askr>=right))
{
return sumnum[now];
}
int mid=(left+right)>>1;
return asksum(now<<1,left,mid,askl,askr)+asksum(now<<1|1,mid+1,right,askl,askr);
}
};
struct tree
{
private:int fa[maxn+10],wson[maxn+10],top[maxn+10],dfn[maxn+10],cnt,deep[maxn+10],size[maxn+10];
private:int pre[(maxn<<1)+10],now[maxn+10],son[(maxn<<1)+10],tot;
private:sigment_tree st;
public:int ins(int a,int b)
//将a和b连一条有向边
{
tot++;
pre[tot]=now[a];
now[a]=tot;
son[tot]=b;
return 0;
}
public:int first_dfs(int u,int father)
//首次dfs,将deep、fa、size、wson求出来,为第二次dfs做准备
{
deep[u]=deep[father]+1;
fa[u]=father;
size[u]=1;
wson[u]=0;
int j=now[u];
while(j)
{
int v=son[j];
if(v!=father)
{
first_dfs(v,u);
size[u]+=size[v];
if((!wson[u])||(size[v]>size[wson[u]]))
{
wson[u]=v;
}
}
j=pre[j];
}
return 0;
}
public:int second_dfs(int u,int father,int topfather)
//第二次dfs,将dfn和top求出来
{
cnt++;
dfn[u]=cnt;
top[u]=topfather;
if(wson[u])
//重儿子的top就是当前节点的top
{
second_dfs(wson[u],u,topfather);
}
int j=now[u];
while(j)
{
int v=son[j];
if((v!=father)&&(v!=wson[u]))
{
second_dfs(v,u,v);
}
j=pre[j];
}
return 0;
}
public:int change(int pos,int val)
//单点修改
{
st.modify(1,1,n,dfn[pos],val);
//直接修改线段树上这个节点就好
return 0;
}
public:int askmax(int x,int y)
//一条路径上求max值
{
int res=-inf;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])
{
std::swap(x,y);
}
res=std::max(res,st.askmax(1,1,n,dfn[top[x]],dfn[x]));
x=fa[top[x]];
}
if(deep[x]>deep[y])
{
std::swap(x,y);
}
//top值相同,那么x和y在同一条重链上
res=std::max(res,st.askmax(1,1,n,dfn[x],dfn[y]));
return res;
}
public:int asksum(int x,int y)
//一条路径上求sum值
{
int res=0;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])
{
std::swap(x,y);
}
res+=st.asksum(1,1,n,dfn[top[x]],dfn[x]);
x=fa[top[x]];
}
if(deep[x]>deep[y])
{
std::swap(x,y);
}
res+=st.asksum(1,1,n,dfn[x],dfn[y]);
return res;
}
};
tree t;
int m;
int main()
{
scanf("%d",&n);
for(int i=1; i<n; i++)
{
int a,b;
scanf("%d%d",&a,&b);
t.ins(a,b);
t.ins(b,a);
}
t.first_dfs(1,0);
t.second_dfs(1,0,1);
for(int i=1; i<=n; i++)
{
int a;
scanf("%d",&a);
t.change(i,a);
}
scanf("%d",&m);
while(m--)
{
char s[7];
int a,b;
scanf("%s%d%d",s,&a,&b);
if(s[1]=='H')
{
t.change(a,b);
}
if(s[1]=='S')
{
printf("%d\n",t.asksum(a,b));
}
if(s[1]=='M')
{
printf("%d\n",t.askmax(a,b));
}
}
return 0;
}