// NowCoder Practice Contest 60, Problem E (牛客练习赛60E)

// Approach: DSU on tree (树上DSU) is used to compute the answer for every vertex.
#include<bits/stdc++.h>
// Competitive-programming shorthand macros.
// Loop variable i over the half-open integer range [s, t).
#define forn(i, s, t) for (int i = (s); i < (int)(t); i++)
#define fi first
#define se second
// Expands to the begin/end iterator pair of a container.
#define all(x) (x).begin(),(x).end()
#define pf2(x,y) printf("%d %d\n",x,y)
#define pf(x) printf("%d\n",x)
// Print every element of a container, one per line.
#define each(x) for(auto it:x) cout<<it<<endl;
#define pi pair<int,int>
#define pb push_back
#define sc(x) scanf("%d",&x)
#define sc2(x,y) scanf("%d%d",&x,&y)
#define mem(a,x) memset(a,x,sizeof(a))
// NOTE(review): this function-like macro hides std::copy for all code after
// this point; arguments are (source, destination) and the byte count is
// taken from the destination array.
#define copy(b,a) memcpy(a,b,sizeof(a))
#define SZ(x) ((int)(x).size())
#define VI vector<int>
#define VII vector<pair<int,int>>
#define PI pair<int,int>
using namespace std;
using ll = long long;
const ll P = 1000000007;    // 1e9+7 (not referenced in the visible code)
const int maxn = 100005;    // capacity of vertex- and depth-indexed arrays
const int maxm = 200005;    // capacity of edge arrays (main stores each undirected edge twice)

// Forward-star adjacency list. Edge ids start at 1, so id 0 terminates a list:
//   head[x] - id of the most recently added edge leaving x (0 if none)
//   ver[e]  - target vertex of edge e
//   nex[e]  - id of the previous edge leaving the same source vertex
//   tot     - number of edges stored so far
// n is the vertex count and k the required distance between paired vertices.
int head[maxn], ver[maxm], nex[maxm], tot, n, k;

// Append the directed edge x -> y to the front of x's edge list.
inline void AddEdge(int x, int y) {
    ++tot;
    ver[tot] = y;
    nex[tot] = head[x];
    head[x] = tot;
}
int size[maxn],son[maxn],cnt[maxn],vis[maxn],dep[maxn],ar[maxn];
ll sum[maxn],ans[maxn];
void dfs(int x,int pa){
size[x]=1;
dep[x]=dep[pa]+1;
for(int i=head[x];i;i=nex[i]){
int y=ver[i];
if(y==pa) continue;
dfs(y,x);
size[x]+=size[y];
if(!son[x] || size[y]>size[son[x]]) son[x]=y;
}
}
// Insert (op == +1) or remove (op == -1) every vertex of the subtree rooted
// at x into/from the depth pool: cnt[d] counts pooled vertices at depth d and
// sum[d] accumulates their weights. pa is x's parent and is never revisited.
void add(int x,int pa,int op){
    const int d=dep[x];
    cnt[d]+=op;
    sum[d]+=op*ar[x];
    for(int e=head[x];e!=0;e=nex[e]){
        const int child=ver[e];
        if(child!=pa) add(child,x,op);
    }
}
void qry(int x,int pa,int lca){
int d=k+2*dep[lca]-dep[x];
if(d>0) {
ans[lca]+=sum[d];
ans[lca]+=cnt[d]*ar[x];
}
for(int i=head[x];i;i=nex[i]){
int y=ver[i];
if(y==pa) continue;
qry(y,x,lca);
}
}
// Second pass (DSU on tree): accumulate ans[] for every vertex in the subtree
// of x. If keep is true, x's whole subtree stays in the depth pool (cnt/sum)
// on return; otherwise the pool is restored to its contents on entry.
void dfs(int x,int pa,bool keep){
    const int heavy=son[x];
    // 1) Solve light children first; each removes its own data afterwards.
    for(int e=head[x];e;e=nex[e]){
        const int y=ver[e];
        if(y!=pa && y!=heavy) dfs(y,x,false);
    }
    // 2) Solve the heavy child and keep its subtree in the pool.
    if(heavy) dfs(heavy,x,true);
    // 3) Merge the light subtrees one at a time. Query before inserting, so
    //    every pair drawn from two different child subtrees is counted once.
    for(int e=head[x];e;e=nex[e]){
        const int y=ver[e];
        if(y==pa || y==heavy) continue;
        qry(y,x,x);
        add(y,x,1);
    }
    // 4) x joins the pool only after its own queries, so pairs that include
    //    x itself are not counted in ans[x].
    cnt[dep[x]]++;
    sum[dep[x]]+=ar[x];
    // 5) Drop x's entire subtree (x included) if the caller does not keep it.
    if(!keep) add(x,pa,-1);
}
int main(){
sc2(n,k);
forn(i,1,n+1)
sc(ar[i]);
forn(i,1,n){
int x,y;
sc2(x,y);
AddEdge(x,y);
AddEdge(y,x);
}
dfs(1,0);
dfs(1,0,1);
forn(i,1,n+1)
printf("%lld%c",ans[i]," \n"[i==n]);
}
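// Web-page footer captured along with this snippet ("浙公网安备 33010602011771号", a Zhejiang public-security website registration number); not part of the program.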