Solution P9846 [ICPC 2021 Nanjing R] Paimon's Tree
提供一种不用特判且比较好写的方式。
考虑如果我们确定了一条直径是答案会怎么样:显然我们希望直径上的数越大越好,因此遇到一个较小的数我们会将其放到非直径上的边(记为跳过边)。但是发现这类跳过边有限制,而选择直径边则会使限制变宽。
因此考虑 dp。找出直径上某点一共能不经过直径上点到达 \(cnt\) 个其它的点,即 \(cnt\) 等于 \(n\) 减去位于直径上它的左右两点的子树大小。那么如果某次将这个点染黑,则可以跳过边的限制增大了 \(cnt\)。
发现由于直径上黑点是连续的一片,可以区间 dp。记 \(f_{i,j,k}\) 表示直径上 \(i\) 到 \(j\) 染过色,且跳过边有 \(k\) 个的最大权值。这时新选择的边权值为 \(a_{j-i+1+k}\)。预处理出 \(i\) 到 \(j\) 的 \(cnt\) 之和,则转移有:\(f_{i-1,j,k}\gets f_{i,j,k}+a_{j-i+1+k},f_{i,j+1,k}\gets f_{i,j,k}+a_{j-i+1+k},f_{i,j,k}\gets f_{i,j,k-1}\ (k\le\sum cnt)\)。
先枚举直径,再做 dp,复杂度是 \(O(n^5)\) 的。
考虑我们在树上同样做一个类似区间 dp 的东西。此时需要一些转化。原本我们知道哪条是直径,但现在如果仍然记 \(f_{i,j,k}\) 表示 \(i\) 到 \(j\) 路径(包含 \(i,j\))上染过色,则可能的 \(cnt\) 有多个(因为不知道下一个往哪里转移)。因此考虑调整状态,\(f_{i,j,k}\) 表示 \(i\) 到 \(j\) 路径(不含 \(i,j\))染过色的答案。为了避免特判,给每个叶子节点都连一个额外的终止节点,这样若遇到终止节点则代表此时这个路径已经选到头了。
发现转移是从 \(dis_{i,j}\) 小的转移到大的,因此可以提前处理出 dp 状态的顺序然后刷表求。单次转移是 \(O(1)\) 的,且只有 \(O(n^3)\) 个状态。时间复杂度 \(O(n^3)\)。
const int N=305;
int n,a[N];
vector<int>e[N];
ll f[N][N][N];int sz[N][N];
int tot,rt,fa[N],siz[N],dep[N];
il void dfs(int x,int f){
siz[x]=(x<=n),fa[x]=f,dep[x]=dep[f]+1;
for(int y:e[x])if(y!=f)dfs(y,x),siz[x]+=siz[y];
}
il int dis(int x,int y){
int ret=0;
if(dep[x]<dep[y])swap(x,y);
while(dep[x]>dep[y])ret++,x=fa[x];
while(x!=y)ret+=2,x=fa[x],y=fa[y];
return ret;
}
int d[N][N];
vector<pii>node[N];
// 找出某个点除去 a,b 子结点后的子树大小
il int getsiz(int x,int a,int b){
if(fa[x]==b)swap(a,b);
if(fa[x]==a)return siz[x]-siz[b]-1;
else return n-siz[a]-siz[b]-1;
}
il void work(){
tot=n=read()+1,rt=1;forto(i,1,n-1)a[i]=read();
forto(i,1,2*n){
e[i].clear(),node[i].clear();
forto(j,1,2*n)forto(k,0,n)f[i][j][k]=0;
}
int x,y;forto(i,2,n)x=read(),y=read(),e[x].eb(y),e[y].eb(x);
forto(i,1,n){
if(e[i].size()==1)++tot,e[i].eb(tot),e[tot].eb(i);
else rt=i;
}
dfs(rt,0);
int mx=0;
forto(i,1,tot)forto(j,1,tot){
d[i][j]=dis(i,j);
if(d[i][j]>1)node[d[i][j]].eb(mkp(i,j)),mx=max(mx,d[i][j]);
}
for(auto[i,j]:node[2]){
int mid=0;
for(int x:e[i])if(d[j][x]<d[i][j]){mid=x;break;}
sz[i][j]=getsiz(mid,i,j);
}
ll ans=0;
forto(len,2,mx){
for(auto[i,j]:node[len]){
int pi,pj;
for(int x:e[i])if(d[x][j]<d[i][j]){pi=x;break;}
for(int x:e[j])if(d[x][i]<d[i][j]){pj=x;break;}
for(int x:e[i])if(x!=pi)sz[x][j]=sz[i][j]+getsiz(i,x,pi);
for(int x:e[j])if(x!=pj)sz[i][x]=sz[i][j]+getsiz(j,x,pj);
forto(k,0,sz[i][j]){
if(k)f[i][j][k]=max(f[i][j][k],f[i][j][k-1]);
if(i<=n){
for(int x:e[i])if(x!=pi)f[x][j][k]=max(f[x][j][k],f[i][j][k]+a[k+d[i][j]-1]);
}
if(j<=n){
for(int x:e[j])if(x!=pj)f[i][x][k]=max(f[i][x][k],f[i][j][k]+a[k+d[i][j]-1]);
}
}
if(i>n&&j>n)ans=max(ans,f[i][j][sz[i][j]]);
}
}
printf("%lld\n",ans);
}
signed main(){
int t=read();while(t--)work();
return 0;
}