Loading

CW 模拟赛 T1.Soso 的并查集写挂了

一轮复习

题面

似乎有原题, 但是很偏
挂个 pdf
题面下载

算法

暴力

很显然, 只需要在并查集维护时稍微加上一点细节

#include <cstdio>
using namespace std;
int n,m,fa[500010],a[500010];
long long ans=0;
int find(int x){
  	ans+=a[x];
	ans%=998244353;
	if(fa[x]==x) return x;
	return find(fa[x]);
}
void merge(int x,int y){
	int tx=find(x);
  	int ty=find(y);
  	fa[ty]=tx;
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++){
		scanf("%d",&a[i]);
		fa[i]=i;
	}
	for(int i=1,x,y;i<=m;i++){
		scanf("%d%d",&x,&y);
		merge(x,y);
	}
	printf("%lld\n",ans);
	return 0;
}

正解

观察到 TLE 的原因是每次查询使用了大量时间, 考虑优化

带权并查集

并查集的优点是能够通过路径压缩优化复杂度

维护边权

这是简单的, 只需要稍稍修改一下 \(\rm{find}\) 函数即可 (其中 \(d\) 存储到根的边权之和)

inline int find(int x) {
	if(fa[x] == x) return x;
	int root = find(fa[x]); //注意一下写法, 先将find(fa[x])存放在root中, 否则会出错 
	d[x] += d[fa[x]];
	return fa[x] = root;
}
维护点权

按照维护边权的方法写完之后发现会出问题
观察到本质上是因为每一次路径压缩都会重复计算最上层的根节点
于是想办法消除其的影响

代码(后补)
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod=998244353;
int n,m;
int a[500010];
pair<int,int> fa[500010];
int ans;
void init(){
	for(int i=1;i<=n;i++){
		fa[i]={i,a[i]};
	}
	return;
}
pair<int,int> find(int x){
	if(fa[x].first==x)return fa[x];
	auto t=find(fa[x].first);
	t.second+=fa[x].second;
	fa[x].second=t.second-fa[t.first].second; 
	fa[x].first=t.first;
	fa[x].second%=mod;
	t.second%=mod;
	return t;
}
signed main(){
	cin>>n>>m;
	for(int i=1;i<=n;i++){
		cin>>a[i];
	}
	init();
	while(m--){
		int x,y;
		cin>>x>>y;
		auto fx=find(x),fy=find(y);
		ans+=fx.second+fy.second;
		if(fx.first!=fy.first){
			fa[fy.first].first=fx.first;
		}
	}
	cout<<ans%mod;
	return 0; 
} 

树上差分

观察到查询操作在最终的树上都是一条链, 考虑并查集 + 建树, 维护树上差分操作

#include <cstdio>
#include <vector>
#include <cstring>
#include <numeric>
#include <algorithm>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr,##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
typedef long long LL;
const int P=998244353;
vector<int> g[1<<19];
template<int N> struct dsu{
	int fa[N+10],cnt;
	explicit dsu(int n=N):cnt(n){iota(fa+1,fa+n+1,1);}
	int find(int x){return fa[x]==x?x:fa[x]=find(fa[x]);}
	void merge(int x,int y){if(x=find(x),y=find(y),x!=y) cnt--,fa[y]=x,g[x].push_back(y);}
};
int n,m,cnt[1<<19],a[1<<19];
dsu<1<<19> s;
LL ans=0;
void dfs(int u){
	for(int v:g[u]) dfs(v),cnt[u]+=cnt[v];
	ans=(ans+1ll*cnt[u]*a[u])%P;
}
int main(){
//	#ifdef LOCAL
//	 	freopen("input.in","r",stdin);
//	#endif
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	for(int i=1,u,v;i<=m;i++){
		scanf("%d%d",&u,&v);
		cnt[u]++,cnt[v]++;
		cnt[s.find(u)]--,cnt[s.find(v)]--;
		ans=(ans+a[s.find(u)]+a[s.find(v)])%P;
		s.merge(u,v);
	}
	for(int i=1;i<=n;i++) if(s.find(i)==i) dfs(i);
	printf("%lld\n",ans);
	return 0;
}

二轮复习

现在回过头来复习

初步分析

首先考虑题意
先按照题目中的要求建出最终形态的树, 然后再表示每次 \(\rm{find}\) 操作的花费
pEFRg61.png

如何处理这些花费

考虑在每次 \(\rm{find}\) 的时候记录链的开头和结尾, 然后最后进行处理
这个时候需要引入树上差分来处理这个问题

树上差分

和链上(线性) 差分一样, 我们考虑对于每一个点进行处理之后搞一个前缀和
这个题一个很好的性质是每一个更改一定在链上, 避开了对 \(\rm{LCA}\) 的处理, 但是这里还是放上完整的点权树上差分和边权树上差分的做法

  • 点差分:
    当我们需要对树上的路径进行访问, 比如查询路径上的点被访问的次数时, 可以使用点差分
    对于一次路径访问, 我们找到路径两端节点的 \(\rm{LCA}\) , 然后对路径上的点进行操作
    具体的差分操作如下: diff[u] += val diff[v] += val diff[lca] -= val diff[father_of_lca] -= val 其中 diff 数组是点权的差分数组, lca 是最近公共祖先, father_of_lcalca 的父节点

  • 边差分:
    当我们需要对路径中的边进行操作时, 就需要使用边差分
    边差分的操作与点差分类似, 但是由于在边上直接进行差分比较困难, 因此我们将操作转移到相邻的节点上
    边差分的操作如下: diff[u] += val diff[v] += val diff[lca] -= val diff[father_of_lca] -= val 这里的操作实质上是对两段路径进行差分

  • 最终答案: 这个是简单的, 遍历一遍维护前缀和即可

最终做法

容易发现每次维护树上差分, 直接做即可

总结

考场上并没有想到怎么维护点权, 自己推的能力需要加强
转化能力是重要的

树上链和考虑使用树上差分


三轮复习

pV0WvLD.png

思路

很显然上次我没有完全搞懂这个问题
考虑分析问题

我们首先简单的将最终的树建出来, 然后每次的花费可以视作一条链的权值和
不难发现快速维护一条祖先-后代链的权值和是比较简单的, 我们预处理树上前缀和即可

#include <bits/stdc++.h>
#define int long long
const int MAXN = 5e5 + 20;

int n, Q;
int val[MAXN];

/*并查集*/
struct DSU {
    int fa[MAXN];
    DSU() { for (int i = 1; i <= 500000; i++) fa[i] = i; }

    int find(int x) { return fa[x] = (fa[x] == x ? x : find(fa[x])); }
    void merge(int lhs, int rhs) {
        int fax = find(lhs), fay = find(rhs);
        if (fax == fay) return;
        fa[fay] = fax;
    }
} dsu;

std::vector<std::pair<int, int> > qry, mrg;

struct EDGE { int u, v, nxt; } e[MAXN];
int head[MAXN], cnt = 0;
void addedge(int u, int v) { e[cnt] = {u, v, head[u]}, head[u] = cnt++; }
int dep[MAXN], pre[MAXN], fa[MAXN];


signed main() {
    memset(head, -1, sizeof(head));

    scanf("%lld %lld", &n, &Q);
    for (int i = 1; i <= n; i++) scanf("%lld", &val[i]);

    for (int i = 1, x, y; i <= Q; i++) {
        scanf("%lld %lld", &x, &y);
        if (dsu.find(x) != dsu.find(y)) addedge(dsu.find(x), dsu.find(y)), addedge(dsu.find(y), dsu.find(x));
        qry.push_back({x, dsu.find(x)}), qry.push_back({y, dsu.find(y)});
        dsu.merge(x, y); mrg.push_back({x, y});
    }

    auto dfs = [&](auto &&dfs, int u, int fat) -> void {
        dep[u] = dep[fat] + 1;
        pre[u] = pre[fat] + val[u];
        fa[u] = fat;
        for (int i = head[u]; ~i; i = e[i].nxt) {
            int v = e[i].v;
            if (v == fat) continue;
            dfs(dfs, v, u);
        }
    }; for (int i = 1; i <= n; i++) if (dsu.find(i) == i) dfs(dfs, i, 0);

    int ans = 0;
    for (auto [x, y] : qry) {
        if (dep[x] > dep[y]) std::swap(x, y);
        ans += pre[y] - pre[fa[x]]; ans %= 998244353;
    }

    printf("%lld", ans);



    return 0;
}
posted @ 2024-10-20 11:54  Yorg  阅读(53)  评论(0)    收藏  举报