#全局平衡二叉树,树链剖分#洛谷 4751 【模板】动态 DP(加强版)
分析
正常的树形dp是 \(f[x][0]+=\max(f[y][0],f[y][1]),f[x][1]+=f[y][0]\)
按照重儿子和轻儿子进行拆分,那么 \(f[x][0]=g[x][0]+\max(f[big[x]][0],f[big[x]][1]),f[x][1]=g[x][1]+f[big[x]][0]\)
那么可以转化为广义矩阵乘法
\[\begin{bmatrix}
f[x][0] \\
f[x][1]
\end{bmatrix}
=
\begin{bmatrix}
g[x][0] & g[x][0] \\
g[x][1] & -\infty
\end{bmatrix}
*
\begin{bmatrix}
f[big[x]][0] \\
f[big[x]][1]
\end{bmatrix}
\]
可以发现单点修改的时候只需要对整个重链进行查询,因此可以对每个重链开一棵线段树就能卡进时限
代码
#include <cstdio>
#include <cctype>
using namespace std;
const int N=1000011,inf=0x3f3f3f3f;
struct node{int y,next;}e[N<<1];
int dfn[N],big[N],Top[N],siz[N],nfd[N],tot,ofn[N],ls[N<<2],rs[N<<2];
int dp[N][2],a[N],dep[N],fat[N],et=1,n,m,as[N],rt[N],cnt;
int iut(){
int ans=0,f=1; char c=getchar();
while (!isdigit(c)) f=(c=='-')?-f:f,c=getchar();
while (isdigit(c)) ans=ans*10+c-48,c=getchar();
return ans*f;
}
inline void print(int ans){
if (ans<0) putchar('-'),ans=-ans;
if (ans>9) print(ans/10);
putchar(ans%10+48);
}
int max(int a,int b){return a>b?a:b;}
struct maix{
int p[2][2];
inline maix operator *(const maix &B)const{
maix C;
C.p[0][0]=max(p[0][0]+B.p[0][0],p[0][1]+B.p[1][0]),
C.p[0][1]=max(p[0][0]+B.p[0][1],p[0][1]+B.p[1][1]),
C.p[1][0]=max(p[1][0]+B.p[0][0],p[1][1]+B.p[1][0]),
C.p[1][1]=max(p[1][0]+B.p[0][1],p[1][1]+B.p[1][1]);
return C;
}
}w[N<<2],A[N];
void build(int &rt,int l,int r){
rt=++cnt;
if (l==r) {w[rt]=A[nfd[l]]; return;}
int mid=(l+r)>>1;
build(ls[rt],l,mid);
build(rs[rt],mid+1,r);
w[rt]=w[ls[rt]]*w[rs[rt]];
}
void update(int rt,int l,int r,int x){
if (l==r) {w[rt]=A[nfd[x]]; return;}
int mid=(l+r)>>1;
if (x<=mid) update(ls[rt],l,mid,x);
else update(rs[rt],mid+1,r,x);
w[rt]=w[ls[rt]]*w[rs[rt]];
}
void dfs1(int x,int fa){
fat[x]=fa,siz[x]=1,dep[x]=dep[fa]+1;
for (int i=as[x],SIZ=-1;i;i=e[i].next)
if (e[i].y!=fa){
dfs1(e[i].y,x),siz[x]+=siz[e[i].y];
if (SIZ<siz[e[i].y]) big[x]=e[i].y,SIZ=siz[e[i].y];
}
}
void dfs2(int x,int linp){
dfn[x]=++tot,nfd[tot]=x,Top[x]=linp,ofn[linp]=tot;
dp[x][0]=0,dp[x][1]=a[x],A[x].p[1][1]=-inf,
A[x].p[0][0]=A[x].p[0][1]=0,A[x].p[1][0]=a[x];
if (!big[x]) return; dfs2(big[x],linp);
dp[x][0]+=max(dp[big[x]][0],dp[big[x]][1]);
dp[x][1]+=dp[big[x]][0];
for (int i=as[x];i;i=e[i].next)
if (e[i].y!=fat[x]&&e[i].y!=big[x]){
dfs2(e[i].y,e[i].y);
int now=max(dp[e[i].y][0],dp[e[i].y][1]);
dp[x][0]+=now,dp[x][1]+=dp[e[i].y][0],
A[x].p[0][0]+=now,A[x].p[1][0]+=dp[e[i].y][0];
}
A[x].p[0][1]=A[x].p[0][0];
}
inline void Update(int x,int z){
for (A[x].p[1][0]+=z-a[x],a[x]=z;x;){
maix B1=w[rt[Top[x]]];
update(rt[Top[x]],dfn[Top[x]],ofn[Top[x]],dfn[x]);
maix B2=w[rt[Top[x]]];
x=fat[Top[x]];
A[x].p[0][0]+=max(B2.p[0][0],B2.p[1][0])-max(B1.p[0][0],B1.p[1][0]),
A[x].p[0][1]=A[x].p[0][0],A[x].p[1][0]+=B2.p[0][0]-B1.p[0][0];
}
}
int main(){
n=iut(); m=iut();
for (int i=1;i<=n;++i) a[i]=iut();
for (int i=1;i<n;++i){
int x=iut(),y=iut();
e[++et]=(node){y,as[x]},as[x]=et;
e[++et]=(node){x,as[y]},as[y]=et;
}
dfs1(1,0),dfs2(1,1);
for (int x=1;x<=n;++x) if (Top[x]==x) build(rt[Top[x]],dfn[Top[x]],ofn[Top[x]]);
for (int i=1,lans=0;i<=m;++i,putchar(10)){
int x=iut()^lans,z=iut(); Update(x,z);
print(lans=max(w[1].p[0][0],w[1].p[1][0]));
}
return 0;
}
分析(全局平衡二叉树)
然而这样复杂度仍然是 log 方的,考虑重链能不能也变成 log 呢,其实是可以的,
不妨对重链按照每个节点轻儿子的大小从中位数分治,轻儿子认父不认子,就得到了全局平衡二叉树,
这样树高是 log 级别的,对这棵辅助树跳father修改即可。
代码
#include <cstdio>
#include <cctype>
#include <vector>
using namespace std;
const int N=1000011,inf=0x3f3f3f3f; struct node{int y,next;}e[N<<1];
int fat[N],siz[N],a[N],f[N][2],g[N][2],light[N],et=1,n,big[N],lights[N],son[N][2],as[N],father[N],gbrt,Q;
int iut(){
int ans=0,f=1; char c=getchar();
while (!isdigit(c)) f=(c=='-')?-f:f,c=getchar();
while (isdigit(c)) ans=ans*10+c-48,c=getchar();
return ans*f;
}
inline void print(int ans){
if (ans<0) putchar('-'),ans=-ans;
if (ans>9) print(ans/10);
putchar(ans%10+48);
}
int max(int a,int b){return a>b?a:b;}
struct maix{
int p[2][2];
inline maix operator *(const maix &B)const{
maix C;
C.p[0][0]=max(p[0][0]+B.p[0][0],p[0][1]+B.p[1][0]),
C.p[0][1]=max(p[0][0]+B.p[0][1],p[0][1]+B.p[1][1]),
C.p[1][0]=max(p[1][0]+B.p[0][0],p[1][1]+B.p[1][0]),
C.p[1][1]=max(p[1][0]+B.p[0][1],p[1][1]+B.p[1][1]);
return C;
}
}w[N],A[N];
void dfs1(int x,int fa){
fat[x]=fa,siz[x]=1;
f[x][0]=0,f[x][1]=a[x];
for (int i=as[x],SIZ=-1;i;i=e[i].next)
if (e[i].y!=fa){
dfs1(e[i].y,x),siz[x]+=siz[e[i].y];
if (SIZ<siz[e[i].y]) big[x]=e[i].y,SIZ=siz[e[i].y];
f[x][0]+=max(f[e[i].y][0],f[e[i].y][1]);
f[x][1]+=f[e[i].y][0];
}
if (big[x]){
light[x]=siz[x]-siz[big[x]];
g[x][0]=f[x][0]-max(f[big[x]][0],f[big[x]][1]);
g[x][1]=f[x][1]-f[big[x]][0];
}else{
light[x]=siz[x];
g[x][0]=f[x][0];
g[x][1]=f[x][1];
}
}
void reset(int x){
A[x].p[0][0]=A[x].p[0][1]=g[x][0];
A[x].p[1][0]=g[x][1],A[x].p[1][1]=-inf;
}
void pup(int x){
if (son[x][0]) w[x]=w[son[x][0]]*A[x];
else w[x]=A[x];
if (son[x][1]) w[x]=w[x]*w[son[x][1]];
}
int build(vector<int>heavy,int l,int r,int fa){
if (l>r) return 0;
lights[r+1]=0;
for (int i=r;i>=l;--i) lights[i]=lights[i+1]+light[heavy[i]];
int mid=l;
for (int i=r;i>l;--i)
if (lights[i]>=lights[l]-lights[i]){
mid=i;
break;
}
int x=heavy[mid];
father[x]=fa;
son[x][0]=build(heavy,l,mid-1,x);
son[x][1]=build(heavy,mid+1,r,x);
reset(x),pup(x);
return x;
}
int dfs2(int x,int fa){
vector<int>heavy;
for (int u=x;u;u=big[u]) heavy.push_back(u);
int rt=build(heavy,0,heavy.size()-1,fa);
for (int u:heavy)
for (int i=as[u];i;i=e[i].next)
if (e[i].y!=fat[u]&&e[i].y!=big[u]) dfs2(e[i].y,u);
return rt;
}
void update(int x,int y){
g[x][1]+=y-a[x],a[x]=y,reset(x);
for (int fa;x;x=fa){
fa=father[x];
if (fa&&son[fa][0]!=x&&son[fa][1]!=x){
maix B1=w[x]; pup(x); maix B2=w[x];
g[fa][0]+=max(B2.p[0][0],B2.p[1][0])-max(B1.p[0][0],B1.p[1][0]),
g[fa][1]+=B2.p[0][0]-B1.p[0][0],reset(fa);
}else pup(x);
}
}
int main(){
n=iut(),Q=iut();
for (int i=1;i<=n;++i) a[i]=iut();
for (int i=1;i<n;++i){
int x=iut(),y=iut();
e[++et]=(node){y,as[x]},as[x]=et;
e[++et]=(node){x,as[y]},as[y]=et;
}
dfs1(1,0);
gbrt=dfs2(1,0);
for (int i=1,lans=0;i<=Q;++i,putchar(10)){
int x=iut()^lans,z=iut(); update(x,z);
print(lans=max(w[gbrt].p[0][0],w[gbrt].p[1][0]));
}
return 0;
}