[题解]P4116 Qtree3
题目描述没看懂(雾)
简单解释一下:
对于节点\(u\),将子树\(u\)中的权值从大到小排序,记“权值乘排名之和”为\(u\)的贡献。
输出总贡献。
下文定义\(w[u]\)为\(u\)的权值,\(siz[u]\)为子树\(u\)的大小。
考虑每个节点对答案的贡献,看起来比较困难。
转而考虑每个权值对答案的贡献。
不难发现,如果将子树\(u\)中的节点按权值从大到小排序,那么这些点的排名依次是\(siz[u],siz[u]-1,\dots,1\)。
所以我们按权值从大到小遍历每个节点。需要统计它的贡献的节点,就是它到根节点路径上的所有节点。
遍历这些节点,在\(v\)点累加的贡献是\(siz[v]\times w[u]\),并且将\(siz[v]\)减去\(1\)。
这样,\(siz[v]\)表示的始终是\(u\)在子树\(v\)中的排名。因为我们的节点是按权值从大到小统计的。
因此我们需要支持的操作:
- 将\(u\)到根节点路径上的\(siz\)统一减去\(1\)。
- 统计\(u\)到根节点的\(siz\)之和(将所求的和\(\times w[u]\)累入答案即可)。
可以用树剖实现。
时间复杂度\(O(n\log^2 n)\)。
点击查看代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
typedef long long ll;
const int N=5e5+10,P=1e9+7;
int n,head[N],fa[N],dep[N],siz[N],son[N],top[N],dfn[N],tim,w[N],idx;
ll ans;
struct edge{int nxt,to;}e[N<<1];
void add(int u,int v){e[++idx]={head[u],v},head[u]=idx;}
inline int lb(int x){return x&-x;}
struct BIT{
int s1[N],s2[N];
void chp(int x,int v){
for(int i=x;i<=n;i+=lb(i)) s1[i]+=v,s2[i]+=v*x;
}
void chr(int x,int y,int v){
chp(x,v),chp(y+1,-v);
}
int query(int x){
int ans=0;
for(int i=x;i;i-=lb(i)) ans+=(x+1)*s1[i]-s2[i];
return ans;
}
int query(int x,int y){
return query(y)-query(x-1);
}
}bit;
struct Node{
int p,w;
}p[N];
void dfs1(int u){
siz[u]=1,dep[u]=dep[fa[u]]+1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa[u]) continue;
fa[v]=u,dfs1(v),siz[u]+=siz[v];
if(siz[v]>siz[son[u]]) son[u]=v;
}
}
void dfs2(int u,int t){
top[u]=t,dfn[u]=++tim,bit.chr(dfn[u],dfn[u],siz[u]);
if(!son[u]) return;
dfs2(son[u],t);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v!=fa[u]&&v!=son[u]) dfs2(v,v);
}
}
void ch_chain(int u,int v,int iv){
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
bit.chr(dfn[top[u]],dfn[u],iv);
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
bit.chr(dfn[v],dfn[u],iv);
}
ll query(int u,int v){
ll ans=0;
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
ans+=bit.query(dfn[top[u]],dfn[u]);
ans%=P;
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
ans+=bit.query(dfn[v],dfn[u]);
return ans%P;
}
signed main(){
cin>>n;
for(int i=1,u,v;i<n;i++) cin>>u>>v,add(u,v),add(v,u);
for(int i=1;i<=n;i++) cin>>w[i],p[i]={i,w[i]};
dfs1(1),dfs2(1,1);
sort(p+1,p+1+n,[](Node a,Node b){return a.w>b.w;});
for(int i=1;i<=n;i++){
(ans+=query(1,p[i].p)%P*p[i].w)%=P;
ch_chain(1,p[i].p,-1);
}
cout<<ans<<"\n";
return 0;
}
浙公网安备 33010602011771号