[kruskal重构树][线段树] Jzoj P5926 naive 的图
题解
- kruskal重构树其实就是加n-1个点,然后每个点存的是点权,然后有一些性质
- 那么可以对于每一个点都开一颗线段树,记录每种颜色出现的次数
- 合并两个集合的时候,可以用size小的集合里的颜色在size大的集合里查询
- 比说,现在颜色是k,那么就是找[k-L,k]和[k+L,inf]的个数
- 每条边的贡献就是总和*该条边的边权
代码
1 #include <cstdio> 2 #include <iostream> 3 #include <algorithm> 4 #define ll long long 5 #define N 200010 6 #define M 500010 7 #define inf 1000000000 8 using namespace std; 9 ll n,m,l,tot,num,sum,ans,c[N],f[N],d[N],head[N]; 10 struct edge { ll x,y,z; }a[M]; 11 struct node { ll to,from; }e[N]; 12 struct tr{ ll l,r,v; }tree[100*N]; 13 bool cmp(edge a,edge b) { return a.z<b.z; } 14 ll getfather(ll x) { return f[x]==x?x:f[x]=getfather(f[x]); } 15 void add(ll x,ll y) { e[++tot].to=y; e[tot].from=head[x]; head[x]=tot; } 16 ll query(ll d,ll l,ll r,ll L,ll R) 17 { 18 if (L>R) return 0; 19 if (l==L&&r==R) return tree[d].v; 20 ll mid=(l+r)>>1; 21 if (R<=mid) return query(tree[d].l,l,mid,L,R); 22 else if (L>mid) return query(tree[d].r,mid+1,r,L,R); 23 else return query(tree[d].l,l,mid,L,mid)+query(tree[d].r,mid+1,r,mid+1,R); 24 } 25 void dfs(ll d,ll x,ll fa) 26 { 27 sum+=query(d,0,inf,0,c[x]-l)+query(d,0,inf,c[x]+l,inf); 28 for (ll i=head[x];i;i=e[i].from) if (e[i].to!=fa) dfs(d,e[i].to,x); 29 } 30 void insert(ll d,ll l,ll r,ll x) 31 { 32 if (l==r) { tree[d].v++; return; } 33 ll mid=(l+r)>>1; 34 if (x<=mid) 35 { 36 if (!tree[d].l) tree[d].l=++num; 37 insert(tree[d].l,l,mid,x); 38 } 39 else 40 { 41 if (!tree[d].r) tree[d].r=++num; 42 insert(tree[d].r,mid+1,r,x); 43 } 44 tree[d].v=tree[tree[d].l].v+tree[tree[d].r].v; 45 } 46 void merge(ll l,ll r,ll L,ll R) 47 { 48 if (L==R) { tree[r].v+=tree[l].v; return; } 49 ll mid=(L+R)>>1; 50 if (tree[l].l&&tree[r].l) merge(tree[l].l,tree[r].l,L,mid); else if (tree[l].l) tree[r].l=tree[l].l; 51 if (tree[l].r&&tree[r].r) merge(tree[l].r,tree[r].r,mid+1,R); else if (tree[l].r) tree[r].r=tree[l].r; 52 tree[r].v=tree[tree[r].l].v+tree[tree[r].r].v; 53 } 54 int main() 55 { 56 freopen("graph.in","r",stdin),freopen("graph.out","w",stdout); 57 scanf("%lld%lld%lld",&n,&m,&l),num=n; 58 for (ll i=1;i<=n;i++) scanf("%lld",&c[i]),f[i]=i,d[i]=1,insert(i,0,inf,c[i]); 59 for (ll i=1;i<=m;i++) scanf("%lld%lld%lld",&a[i].x,&a[i].y,&a[i].z); 60 sort(a+1,a+m+1,cmp); 61 ll i=1,tot=0; 62 while (tot<n-1) 63 { 64 ll u=getfather(a[i].x),v=getfather(a[i].y); 65 if (u!=v) 66 { 67 if (d[u]>d[v]) swap(a[i].x,a[i].y),swap(u,v); 68 if (!l) sum=d[u]*d[v]; else sum=0,dfs(v,u,0); 69 ans+=sum*a[i].z,add(v,u),merge(u,v,0,inf); 70 f[u]=v,d[v]+=d[u];d[u]=0; 71 ++tot; 72 } 73 i++; 74 } 75 printf("%lld\n",ans); 76 }