模拟赛 T2
链接:https://sjzezoj.com/contest/428/problem/2449
代码借鉴了 aqz 大蛇,考场上想到差不多的东西,只是最后我没有想出来具体怎么对中间的点进行统计,所以打了性质 A, 拿到了 40 pts。
记录一下我自己的办法怎么转化成正解。
还是一样的建 Kruskal 重构树。
我们维护一下每个节点子树中有没有指定地点,如果有,那么对于它的子树中的 1 节点,它如果是 lca, 那么这个无疑就是最优的节点。
所以我们记录一下从根到每一个节点对于答案的贡献是多少,并 dfs 出来离它最近的,子树中有制定地点的节点。
对于每个叶子是 0/1 的询问就找到这个节点减去它自身的贡献。
点击查看代码
#include <bits/stdc++.h>
/*愉快抄抄aqz*/
#define int long long
using namespace std;
const int MN=1e6+116;
struct Side{
int u, v, w;
bool operator <(const Side &o)const{
return w<o.w;
}
}side[MN];
int n, k, tot, ans, t[MN];
int lc[MN], rc[MN], val[MN];
int sum1[MN], sum0[MN], val1[MN], val0[MN];
int father[MN], to[MN];
bool vis[MN];
int find(int x){
if(father[x]!=x) father[x]=find(father[x]);
return father[x];
}
void dfs1(int u){
if(!lc[u]&&!rc[u]){
if(t[u]==0) sum0[u]++;
else sum1[u]++;
return;
}else{
dfs1(lc[u]); dfs1(rc[u]);
father[lc[u]]=father[rc[u]]=u;
sum1[u]=sum1[lc[u]]+sum1[rc[u]];
sum0[u]=sum0[lc[u]]+sum0[rc[u]];
vis[u]=(vis[lc[u]]|vis[rc[u]]);
}
}
void dfs2(int u, int top){
val0[u]=val0[u]+sum1[u]*val[u]-sum1[u]*val[father[u]];
val1[u]=val1[u]+sum0[u]*val[u]-sum0[u]*val[father[u]];
top=(vis[u]?u:top);
if(!lc[u]&&!rc[u]){
to[u]=top;
return;
}else{
val0[lc[u]]+=val0[u]; val0[rc[u]]+=val0[u];
val1[lc[u]]+=val1[u]; val1[rc[u]]+=val1[u];
dfs2(lc[u],top); dfs2(rc[u],top);
}
}
int ask1(int u){
if(t[u]!=0) return val1[to[u]];
else return val1[to[u]]-val[to[u]];
}
int ask0(int u){
if(t[u]!=1) return val0[to[u]];
else return val0[to[u]]-val[to[u]];
}
void Read(){
cin>>n>>k; tot=n; for(int i=1; i<=n; ++i) father[i]=i;
for(int i=1; i<=n+n; ++i){
vis[i]=sum1[i]=sum0[i]=val0[i]=val1[i]=0;
to[i]=val[i]=lc[i]=rc[i]=0;
}
for(int i=1; i<=n; ++i) cin>>t[i];
for(int i=1; i<n; ++i){
cin>>side[i].u>>side[i].v>>side[i].w;
}
sort(side+1,side+n);
for(int i=1; i<n; ++i){
++tot; val[tot]=side[i].w; father[tot]=tot;
int u=find(side[i].u), v=find(side[i].v);
lc[tot]=u, rc[tot]=v;
father[u]=father[v]=tot;
}
for(int i=1,x; i<=k; ++i){
cin>>x; vis[x]=true;
}
dfs1(tot); father[tot]=0; dfs2(tot,0);
int ans=0;
for(int i=1; i<=n; ++i) if(t[i]==0) ans+=ask0(i);
cout<<ans<<'\n';
for(int i=1; i<=n; ++i){
if(t[i]==0){
cout<<ans-ask0(i)+ask1(i)<<'\n';
}else{
cout<<ans-ask1(i)+ask0(i)<<'\n';
}
}
}
signed main(){
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
int T; cin>>T; while(T--){
Read();
}
return 0;
}

浙公网安备 33010602011771号