dsu on tree

https://zhuanlan.zhihu.com/p/560661911

https://codeforces.com/contest/600/problem/E

非常裸的题 前面知乎链接里面有讲

const int N = 1e5 + 5;
// n = number of nodes, c[u] = colour of node u, id[u] = index of the bucket in sub[]
// that currently holds u's subtree, tot = last bucket index handed out.
int n, c[N], id[N], tot = 0;
/// One subtree bucket: colour counts, the dominant-colour sum, and the list of its nodes.
struct node {
    int mx_cnt = 0;        // largest occurrence count of any colour in the bucket
    long long mx_sum = 0;  // sum of the colours whose count equals mx_cnt
    map<int, int> cnt;     // colour -> number of occurrences
    vector<int> list;      // every node added so far (replayed when this bucket is a light child)
    /// Add node u (colour c[u]) and keep mx_cnt / mx_sum up to date.
    void add(int u) {
        // Look the counter up once and work through the reference (was three lookups).
        int &k = cnt[c[u]];
        ++k;
        if (k > mx_cnt) mx_cnt = k, mx_sum = c[u];
        else if (k == mx_cnt) mx_sum += c[u];
        list.push_back(u);
    }
    /// Number of nodes in the bucket.
    int size() const { return list.size(); }
} sub[N];
long long ans[N];  // ans[u] = sum of the most frequent colours in u's subtree
vector<int> g[N];  // undirected adjacency lists
void dfs(int u, int fa) {
    id[u] = ++tot;
    int mx_son = -1, mx_sz = 0;
    for (int v : g[u]) {
        if (v == fa)continue;
        dfs(v, u);
        if (sub[id[v]].size() > mx_sz) {
            mx_sz = sub[id[v]].size();
            mx_son = v;
        }
    }
    if(mx_son!=-1)id[u] = id[mx_son];
    for (int v : g[u]) {
        if (v == fa)continue;
        if (v == mx_son)continue;
        for (int son : sub[id[v]].list)
            sub[id[u]].add(son);
    }
    sub[id[u]].add(u);
    ans[u] = sub[id[u]].mx_sum;
}
/// Read n, the colours c[1..n] and n-1 undirected edges; solve from root 1 and
/// print ans[1..n] separated by spaces on a single line.
void slove() {
    cin >> n;
    for (int i = 1; i <= n; ++i) cin >> c[i];
    for (int e = 1; e < n; ++e) {
        int a, b;
        cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    dfs(1, 0);
    for (int i = 1; i <= n; ++i) cout << ans[i] << " ";
    cout << endl;
}

https://codeforces.com/contest/1009/problem/F

对于树上每个节点 P,求一个最小的 K,使得其子树中到 P距离为K 的节点数最多

分析:

和上面一个题目类似：我们只需要将颜色改为深度，在节点个数最多的深度中取最小的那个深度，再减去 P 自身的深度，即为所求的 K。

#include<bits/stdc++.h>
using namespace std;
#define lowbit(x) x&(-x)
#define ll long long
const int N = 1e6 + 5;
// c[u] = depth of u (root has depth 0), id[u] = bucket index holding u's subtree.
int n, c[N], id[N], tot = 0;
/// One subtree bucket: number of nodes at each depth, plus the node list.
struct node {
    int mx_cnt = 0;     // largest number of nodes sharing one depth
    int mx_id = 1e9;    // smallest depth that reaches mx_cnt
    map<int, int> cnt;  // depth -> node count
    vector<int> list;   // nodes in the bucket
    /// Add node u (depth c[u]); on a tie in the count the smaller depth wins.
    void add(int u) {
        int &k = cnt[c[u]];  // single map lookup (was three)
        ++k;
        if (k > mx_cnt) mx_cnt = k, mx_id = c[u];
        else if (k == mx_cnt) mx_id = min(c[u], mx_id);
        list.push_back(u);
    }
    /// Number of nodes in the bucket.
    int size() const { return list.size(); }
} sub[N];
long long ans[N];  // ans[u] = the answer K for node u
vector<int> g[N];  // undirected adjacency lists
void dfs(int u, int fa) {
    id[u] = ++tot;
    int mx_son = -1, mx_sz = 0;
    for (int i=0;i<g[u].size();i++) {
        int v=g[u][i];
		if (v == fa)continue;
		c[v]=c[u]+1;
        dfs(v, u);
        if (sub[id[v]].size() > mx_sz) {
            mx_sz = sub[id[v]].size();
            mx_son = v;
        }
    }
    if(mx_son!=-1)id[u] = id[mx_son];
   for (int i=0;i<g[u].size();i++) {
        int v=g[u][i];
        if (v == mx_son||v==fa)continue;
        for(int j=0;j<sub[id[v]].list.size();j++)
        	sub[id[u]].add(sub[id[v]].list[j]);
}
    sub[id[u]].add(u);
    ans[u] =sub[id[u]].mx_id-c[u];
}
/// Read n and n-1 undirected edges, run dfs from node 1 and print ans[1..n], one per line.
void solve() {
    cin >> n;
    for (int i = 1; i <= n - 1; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1, 0);
    // ans[] is long long: printing it with %d is undefined behaviour, so use %lld.
    for (int i = 1; i <= n; i++) printf("%lld\n", ans[i]);
}
// Entry point: a single test case; main returns 0 implicitly.
int main() {
    solve();
}

https://www.luogu.com.cn/problem/U41492

分析:

也是一道很裸的题目

#include<bits/stdc++.h>
using namespace std;
#define lowbit(x) x&(-x)
#define ll long long
const int N = 1e5 + 5;
// c[u] = colour of u, id[u] = bucket index holding u's subtree.
int n, id[N], c[N], tot;
/// One subtree bucket: colour multiset plus its number of distinct colours.
struct node {
    int _cnt = 0;       // number of distinct colours in the bucket
    map<int, int> cnt;  // colour -> occurrences
    vector<int> list;   // nodes in the bucket
    /// Add node u; a colour seen for the first time increments _cnt.
    void add(int u) {
        // Single map lookup: test the old count and increment it in one step.
        if (cnt[c[u]]++ == 0) ++_cnt;
        list.push_back(u);
    }
    /// Number of nodes in the bucket.
    int size() const { return list.size(); }
} sub[N];
long long ans[N];  // ans[u] = number of distinct colours in u's subtree
vector<int> g[N];  // undirected adjacency lists
void dfs(int u, int fa) {
    id[u] = ++tot;
    int mx_son = -1, mx_sz = 0;
    for (int i=0;i<g[u].size();i++) {
        int v=g[u][i];
		if (v == fa)continue;
        dfs(v, u);
        if (sub[id[v]].size() > mx_sz) {
            mx_sz = sub[id[v]].size();
            mx_son = v;
        }
    }
    if(mx_son!=-1)id[u] = id[mx_son];
   for (int i=0;i<g[u].size();i++) {
        int v=g[u][i];
        if (v == mx_son||v==fa)continue;
        for(int j=0;j<sub[id[v]].list.size();j++)
        	sub[id[u]].add(sub[id[v]].list[j]);
}
    sub[id[u]].add(u);
    ans[u] =sub[id[u]]._cnt;
}
/// Read n, the n-1 undirected edges, then the colours c[1..n]; run dfs from node 1
/// and answer q queries "number of distinct colours in the subtree of x".
void solve() {
    cin >> n;
    for (int i = 1; i <= n - 1; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
    dfs(1, 0);

    int q;
    scanf("%d", &q);
    while (q--) {
        int x;
        scanf("%d", &x);
        // ans[] is long long: %d would be undefined behaviour.
        printf("%lld\n", ans[x]);
    }
}
// Entry point: a single test case; main returns 0 implicitly.
int main() {
    solve();
}

https://codeforces.com/contest/1618/problem/G

const int maxnum = 2e5 + 5;
int a[maxnum], b[maxnum];
// The answer is a sum of up to n item values (a_i up to 1e9 per the CF1618G statement),
// which does not fit in int; keep the running total and the answers 64-bit.
long long ans[maxnum];
// Per component root: sta = values currently counted in `sum`, stb = the other values.
multiset<int> sta[2 * maxnum], stb[2 * maxnum];
long long sum = 0;  // the answer: total of the values in all sta[] sets
struct query { int k, t, id; };  // k: gap or query value, t: 1 = query / 0 = merge event, id: index
int n, m, q;
/// Union-find over the positions of the sorted value list. Every root owns the
/// multisets sta[root] (values counted in `sum`) and stb[root] (the rest).
struct DSU
{
	int fa[2 * maxnum];
	/// Make every index its own set.
	void init() { for (int i = 0; i < 2 * maxnum; i++) fa[i] = i; }
	/// Root of x, with path compression.
	int find(int x) { return x == fa[x] ? x : (fa[x] = find(fa[x])); }
	/// Merge the sets of x and y (smaller multisets poured into the larger), then swap
	/// values between sta and stb until sta holds the largest values of the merged set.
	/// `sum` is adjusted for every swap.
	void merge(int x, int y) {
		x = find(x), y = find(y);
		// Already one set: merging into itself would insert into the set being iterated.
		if (x == y) return;
		if (sta[x].size() + stb[x].size() < sta[y].size() + stb[y].size()) swap(x, y);
		for (int it : sta[y]) sta[x].insert(it);
		for (int it : stb[y]) stb[x].insert(it);
		// y stops being a root and its multisets are never read again; free them.
		sta[y].clear();
		stb[y].clear();
		while (sta[x].size() && stb[x].size() && *sta[x].begin() < *stb[x].rbegin()) {
			int mi = *sta[x].begin(), mx = *stb[x].rbegin();
			sum -= mi, sum += mx;
			sta[x].erase(sta[x].find(mi)), stb[x].erase(stb[x].find(mx));
			sta[x].insert(mx), stb[x].insert(mi);
		}
		fa[y] = x;
	}
} uf;

void slove() {
	uf.init();
	cin >> n >> m >> q;
	vector<pair<int, int>>v; v.push_back({ 0,0 });
	for (int i = 1; i <= n; i++) { cin >> a[i]; v.push_back({ a[i],1 }); sum += a[i]; }
	for (int i = 1; i <= m; i++) { cin >> b[i]; v.push_back({ b[i],0 }); }
	sort(v.begin(), v.end());//一定记得排序,pair的排序是默认按照先first后second
	vector<query>que;
	//这里将a b 分开
	for (int i = 1; i <= n + m; i++) {
		if (v[i].second)sta[i].insert(v[i].first);
		else stb[i].insert(v[i].first);
	}//排序后,保存两数之前的差值
	for (int i = 1; i < n + m; i++)que.push_back({ v[i + 1].first - v[i].first,0,i });
	for (int i = 1; i <= q; i++) {int qq; cin >> qq;que.push_back({ qq,1,i });}
	sort(que.begin(), que.end(), [](const query & a, const query & b) {
		if (a.k != b.k)return a.k < b.k;
		return a.t < b.t;
	});//lambda真好用(bushi
	for (auto q : que) {
		if (q.t)ans[q.id] = sum;
		else uf.merge(q.id, q.id + 1);
	}
	for (int i = 1; i <= q; i++)cout << ans[i] << endl;
}

https://www.luogu.com.cn/problem/U95602

问题的关键在于：如果为每个节点都维护一份连续段，空间会炸；如果用全局变量，各个子树之间又会互相影响。

巧妙点：我们只对轻儿子进行合并操作，所以先更新完重儿子，再将关键值 now 变为 dfn[son[x]]，即可将轻儿子的连续段合并到重儿子中。

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>

using namespace std;

using LL = long long;
const int N = 200009;
// Tree (forward star: head/g/cnt), a[x] = input position of value x, weights p[],
// heavy child son[], sizes siz[], preorder intervals [L, R], run links l[]/r[],
// preorder DFN[] and its inverse rev[] (counter Index), visit stamps vis[] and the
// current stamp now.
int n, head[N], cnt, p[N], a[N], son[N], siz[N];
int L[N], R[N], l[N], r[N], Index, DFN[N], rev[N], vis[N], now;
LL ans[N], Ans, s[N];
/// Forward-star edge: next edge leaving the same vertex, and the edge's endpoint.
struct Edge
{
	int nxt, to;
} g[N * 2];

/// Append the directed edge from -> to at the front of from's edge list.
void add(int from,int to)
{
	++cnt;
	g[cnt].to = to;
	g[cnt].nxt = head[from];
	head[from] = cnt;
}

/// Read n; the parent of every node 2..n (stored as edges in both directions);
/// n values x, where a[x] records the 1-based position at which x appears;
/// and the weights p[1..n].
void init()
{
	scanf("%d", &n);
	for (int v = 2; v <= n; ++v)
	{
		int par;
		scanf("%d", &par);
		add(par, v);
		add(v, par);
	}
	for (int pos = 1; pos <= n; ++pos)
	{
		int val;
		scanf("%d", &val);
		a[val] = pos;
	}
	for (int v = 1; v <= n; ++v)
		scanf("%d", &p[v]);
}

/// First pass: subtree size siz[x], heavy child son[x], preorder number DFN[x]
/// (rev[] is its inverse) and the preorder interval [L[x], R[x]] of x's subtree.
void dfs(int x,int fa)
{
	DFN[x] = ++Index;
	rev[Index] = x;
	siz[x] = 1;
	for (int e = head[x]; e; e = g[e].nxt)
	{
		const int v = g[e].to;
		if (v == fa) continue;
		dfs(v, x);
		siz[x] += siz[v];
		if (siz[v] > siz[son[x]]) son[x] = v;
	}
	L[x] = DFN[x];
	R[x] = Index;
}

/// Number of contiguous sub-intervals of a run of length x: x * (x + 1) / 2.
long long calc(int x) { const long long len = x; return len * (len + 1) / 2; }

// Add position x to the set of positions stamped with the current `now`, merging it
// with the maximal runs of consecutive present positions on either side. ans[k] is
// updated so that it stays the sum over runs of len*(len+1)/2, i.e. the number of
// contiguous position intervals [i, j] lying entirely inside the set.
// l[y] / r[y] are only meaningful at run endpoints: the left end's r[] is the run's
// right end, and the right end's l[] is the run's left end.
void Insert(int x,int k)
{
	// Mark x present under this stamp; on its own it forms the run [x, x].
	vis[x]=now,l[x]=r[x]=x;
	// A neighbour carrying another stamp is left over from some other subtree's state:
	// treat it as absent by zeroing its stale links.
	if(vis[x+1]!=now) l[x+1]=r[x+1]=0;
	if(vis[x-1]!=now) l[x-1]=r[x-1]=0;
	// xl / xr = lengths of the run ending at x-1 and the run starting at x+1 (0 if none).
	int xl=0,xr=0;
	if(l[x-1]&&r[x+1])
	{
		// Both sides present: join [l[x-1], x-1] + x + [x+1, r[x+1]].
		xl=x-l[x-1],xr=r[x+1]-x;
		r[l[x-1]]=r[x+1],l[r[x+1]]=l[x-1];
	}
	else if(l[x-1])
	{
		// Only the left run exists: extend it so it ends at x.
		xl=x-l[x-1];
		r[l[x-1]]=x,l[x]=l[x-1];
	}
	else if(r[x+1])
	{
		// Only the right run exists: extend it so it starts at x.
		xr=r[x+1]-x;
		l[r[x+1]]=x,r[x]=r[x+1];
	}
	// Replace the interval counts of the old runs by the count of the merged run.
	ans[k]=ans[k]-calc(xl)-calc(xr)+calc(xl+xr+1);
}

// Second pass. ans[x] ends up as the number of contiguous position intervals (positions
// a[y]) all of whose nodes lie in x's subtree. p[x] times the count of such intervals
// that are not contained in a single child's subtree is added to Ans.
// Stamps isolate subtrees: all nodes of one heavy chain share one stamp (`now`), namely
// the original preorder number of the leaf at the chain's bottom, so Insert ignores
// state left in vis/l/r by other chains.
void DFS(int x,int fa)
{
	// Light children first; each builds its own state under its own stamp.
	for (int i=head[x];i;i=g[i].nxt)
	{
		int v=g[i].to;
		if(v==fa||v==son[x])
			continue;
		DFS(v,x);
	}
	// Stamp used when x is a leaf; replaced by the heavy chain's stamp below otherwise.
	now=DFN[x];
	if(son[x])
	{
		DFS(son[x],x);
		// Build on top of the heavy child's runs: inherit its interval count and stamp.
		ans[x]+=ans[son[x]];
		// DFN[x] is overwritten with the chain stamp so x's parent can read it via
		// DFN[son[...]]; L[]/R[] still hold the original preorder numbering.
		DFN[x]=now=DFN[son[x]];
		for (int i=head[x];i;i=g[i].nxt)
		{
			int v=g[i].to;
			if(v==fa||v==son[x])
				continue;
			// Insert every node of this light subtree through its preorder interval.
			for (int j=L[v];j<=R[v];j++)
				Insert(a[rev[j]],x);
		}
	}
	Insert(a[x],x);
	// k = intervals inside x's subtree minus those inside some child's subtree
	// (children's subtrees are disjoint, so each such interval is subtracted once).
	LL k=ans[x];
	for (int i=head[x];i;i=g[i].nxt)
	{
		int v=g[i].to;
		if(v==fa) continue;
		k-=ans[v];
	}
	Ans+=k*p[x];
}

/// Run both passes from root 1, then print the accumulated total.
void work()
{
	dfs(1, -1);  // sizes, heavy children, preorder intervals
	DFS(1, -1);  // run counting and accumulation into Ans
	printf("%lld\n", Ans);
}

// Entry point: read the input, then solve; main returns 0 implicitly.
int main()
{
	init();
	work();
}

https://www.luogu.com.cn/problem/CF570D

如果对每个节点的子树都建立一个 cnt[26] 的 map，是会超空间的。本题巧妙地将统计个数变成了异或问题（只需知道每个字母出现次数的奇偶性），这样每个节点就只需要开一个 map。

还有另一种版本 全局变量

原版本code:

#include<bits/stdc++.h>
using namespace std;
#define lowbit(x) x&(-x)
#define ll long long
const int N = 5e5 + 5;
// dp[u] = depth of u (root 1 has depth 0), id[u] = bucket index holding u's subtree.
int n, m, id[N], dp[N], tot;
// Queries attached to node u: pd[u][i] = target depth (0-based), num[u][i] = query index.
vector<int> pd[N], num[N];
int ans[N];   // ans[i] = 1 iff query i has at most one letter with an odd count
char c[N];    // c[u] = letter on node u
/// One subtree bucket. For every depth it keeps a parity mask: bit i is set iff
/// letter 'a' + i occurs an odd number of times at that depth.
struct node {
    map<int, int> cnt;  // depth -> parity mask
    vector<int> list;   // nodes in the bucket
    /// Toggle u's letter in the mask of u's depth.
    void add(int u) {
        int I = c[u] - 'a';
        cnt[dp[u]] ^= (1 << I);
        list.push_back(u);
    }
    /// Number of nodes in the bucket.
    int size() const { return list.size(); }
    /// Number of letters with an odd count at depth d (0 when no node has depth d).
    /// Uses find() so that querying an absent depth does not insert an empty entry.
    int calc(int d) const {
        auto it = cnt.find(d);
        if (it == cnt.end()) return 0;
        int res = 0;
        for (int i = 0; i < 26; i++)
            if (it->second & (1 << i)) res++;
        return res;
    }
} sub[N];
vector<int> g[N];  // children lists (parent -> child)
void dfs(int u) {
    id[u] = ++tot;
    int mx_son = -1, mx_sz = 0;
    for (int i=0;i<g[u].size();i++) {
        int v=g[u][i];
		dp[v]=dp[u]+1;
        dfs(v);
        if (sub[id[v]].size() > mx_sz) {
            mx_sz = sub[id[v]].size();
            mx_son = v;
        }
    }
    if(mx_son!=-1)id[u] = id[mx_son];
   for (int i=0;i<g[u].size();i++) {
        int v=g[u][i];
        if (v == mx_son)continue;
        for(int j=0;j<sub[id[v]].list.size();j++)
        	sub[id[u]].add(sub[id[v]].list[j]);
}
    sub[id[u]].add(u);
    for(int i=0;i<pd[u].size();i++){
    	int ii=num[u][i];
    	int ak=pd[u][i];
    	int res=sub[id[u]].calc(ak);
    	if(res<=1)
		ans[ii]=true;
		else ans[ii]=false;
	}
}
/// Read n, m, the parents of nodes 2..n, the letters c[1..n] and m queries (x, depth);
/// answer them offline with dfs(1) and print Yes/No for each query in order.
void solve() {
    cin >> n >> m;
    for (int child = 2; child <= n; child++) {
        int par;
        scanf("%d", &par);
        g[par].push_back(child);
    }
    for (int i = 1; i <= n; i++) cin >> c[i];
    for (int qi = 1; qi <= m; qi++) {
        int x, dep;
        scanf("%d%d", &x, &dep);
        pd[x].push_back(dep - 1);  // stored 0-based to match dp[] (root has dp = 0)
        num[x].push_back(qi);
    }
    dfs(1);
    for (int qi = 1; qi <= m; qi++)
        fputs(ans[qi] ? "Yes\n" : "No\n", stdout);
}
// Entry point: a single test case; main returns 0 implicitly.
int main() {
    solve();
}

全局变量版本code:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define ll long long
using namespace std;
/// Read one (optionally negative) decimal integer from stdin using getchar().
inline int read(){
    char ch = getchar();
    // Skip everything that is neither a digit nor a minus sign.
    while (ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    int sign = 1;
    if (ch == '-') {
        sign = -1;
        ch = getchar();
    }
    int val = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar()) val = val * 10 + ch - '0';
    return val * sign;
}
const int N=500005;
int n,m,visit[N],ans[N];char s[N];
// Subtree sizes are named sz, not `size`: with `using namespace std;` under C++17 a
// global `size` clashes with std::size, so every unqualified use is ambiguous.
int sz[N],son[N],sum[N],val[N],dep[N];
int tot,head[N],nxt[N],to[N];
/// Add the directed tree edge a -> b (parent to child).
inline void add(int a,int b){
	nxt[++tot]=head[a];head[a]=tot;to[tot]=b;
}
/// Offline query record: target depth h, and the next query on the same vertex.
struct node{int h,nxt;}a[N];int Head[N];
/// Attach query `id` (target depth h) to vertex v.
inline void Add(int v,int h,int id){
    // Standard aggregate initialisation instead of the GNU compound literal (node){...}.
    a[id]=node{h,Head[v]};Head[v]=id;
}
/// Compute depths, subtree sizes sz[] and heavy children son[] below u (dep[u] must be set).
inline void pre_dfs(int u){
	sz[u]=1;
	for(int i=head[u];i;i=nxt[i]){
		int v=to[i];dep[v]=dep[u]+1;
		pre_dfs(v);sz[u]+=sz[v];
		if(sz[v]>sz[son[u]])son[u]=v;
	}
}
/// XOR every node of u's subtree into sum[depth], skipping subtrees whose root is
/// flagged in visit[]. XOR is its own inverse, so a second call removes the contribution.
inline void update(int u){
	sum[dep[u]] ^= val[u];
	for (int e = head[u]; e; e = nxt[e]) {
		const int v = to[e];
		if (!visit[v]) update(v);
	}
}
/// True iff at most one of bits 0..25 of x is set (at most one letter with an odd count).
inline bool calc(int x){
	const int low = x & ((1 << 26) - 1);  // only the 26 letter bits are inspected
	return (low & (low - 1)) == 0;         // zero, or exactly one bit set
}
// Keep/discard dsu on tree. sum[] is empty on entry to every call. While u is being
// answered, sum[h] is the XOR of val over u's subtree at depth h. On return that
// contribution is kept if keep != 0 and removed again if keep == 0.
inline void dfs(int u,int keep){
	// Light children first, each solved and then discarded (keep = 0).
	for(int i=head[u];i;i=nxt[i]){
		int v=to[i];if(son[u]==v)continue;
		dfs(v,0);
	}
	// Heavy child last and kept; flag it so update(u) does not XOR it in a second time.
	if(son[u])dfs(son[u],1),visit[son[u]]=1;update(u);
	// Queries on u: at most one odd letter at depth h means a palindrome can be formed.
	for(int i=Head[u];i;i=a[i].nxt)ans[i]=calc(sum[a[i].h]);
	// Unflag the heavy child. When discarding, a full update(u), heavy child included,
	// XORs the whole subtree out again.
	visit[son[u]]=0;if(!keep)update(u);
}
int main(){
	n = read(), m = read();
	// Parents of nodes 2..n; edges are stored parent -> child only.
	for (int v = 2; v <= n; ++v) add(read(), v);
	scanf("%s", s + 1);
	// val[u]: a single bit identifying the letter on node u.
	for (int u = 1; u <= n; ++u) val[u] = 1 << (s[u] - 'a');
	// Attach every query (vertex, depth) to its vertex for offline processing.
	for (int i = 1; i <= m; ++i) {
		const int v = read();
		const int h = read();
		Add(v, h, i);
	}
	dep[1] = 1;  // root depth must be set before pre_dfs; query depths are 1-based
	pre_dfs(1);
	dfs(1, 1);
	for (int i = 1; i <= m; ++i) puts(ans[i] ? "Yes" : "No");
	return 0;
}


https://www.luogu.com.cn/problem/CF246E

题意：给定一片森林，每次询问一个节点的 K-Son 共有多少个不同的名字。一个节点的 K-Son 即为在该节点子树内、深度等于该节点深度加 K 的节点。

分析:

这样就不能用每个节点各自开容器的局部变量版本，只能用全局变量；下面的代码是非常标准的全局变量版 dsu on tree。

#include<bits/stdc++.h>
using namespace std;
#define maxn 100010
#define int long long
int n,x,y,m;
string a[maxn];                        // a[i] = name of node i
vector<pair<int,int> >v[maxn];         // v[x] = (absolute target depth, query index) for queries on x
vector<pair<int,int> >::iterator it;   // scratch iterator used by dsu()
int head[maxn],Next[maxn],ver[maxn],tot;  // forward-star child lists
/// Append the directed edge x -> y (parent to child).
void add(int x,int y){
	Next[++tot] = head[x];
	ver[tot] = y;
	head[x] = tot;
}
int dep[maxn],son[maxn],siz[maxn];
/// Compute subtree sizes siz[], heavy children son[] and depths below x (dep[x] must be set).
void dfs(int x){
	siz[x] = 1;
	for (int e = head[x]; e; e = Next[e]) {
		const int ch = ver[e];
		dep[ch] = dep[x] + 1;
		dfs(ch);
		siz[x] += siz[ch];
		if (siz[ch] > siz[son[x]]) son[x] = ch;
	}
}
set<string>S[maxn*2];  // S[d] = distinct names currently recorded at absolute depth d
int Ans[maxn];         // Ans[i] = answer to query i
/// Wipe the depth buckets of every depth occurring in x's subtree (whole buckets are cleared).
void del(int x){
	S[dep[x]].clear();
	for (int e = head[x]; e; e = Next[e]) del(ver[e]);
}
/// Record x's name at x's depth.
void upd(int x){
	S[dep[x]].insert(a[x]);
}
/// Record the names of every node in x's subtree.
void add(int x){
	upd(x);
	for (int e = head[x]; e; e = Next[e]) add(ver[e]);
}
int vis[maxn];
/// dsu on tree for node x. S[] is empty on entry (callers del() each solved light
/// subtree or root). On return S[] describes x's subtree and every query attached
/// to x has been answered.
void dsu(int x){
	vis[x] = 1;
	// Light children: solve each one, then wipe its depths again.
	for (int e = head[x]; e; e = Next[e]) {
		const int ch = ver[e];
		if (ch == son[x]) continue;
		dsu(ch);
		del(ch);
	}
	// Heavy child: solve it and keep its data.
	if (son[x]) dsu(son[x]);
	// Put the light subtrees and x itself on top of the heavy child's data.
	for (int e = head[x]; e; e = Next[e]) {
		if (ver[e] == son[x]) continue;
		add(ver[e]);
	}
	upd(x);
	// S[d] now holds exactly the distinct names of x's subtree at absolute depth d.
	for (it = v[x].begin(); it != v[x].end(); ++it) Ans[it->second] = S[it->first].size();
}
/// Read the forest (name and parent of each node; parent 0 marks a root), answer the
/// queries offline with dsu on tree and print one answer per line.
signed main(){
	cin>>n;
	vector<int> roots;  // nodes whose parent is 0
	for(int i=1;i<=n;i++){
		cin>>a[i]>>x;
		if(x)add(x,i);
		else roots.push_back(i);
	}
	// Start the traversals only from real roots. Starting from every node whose dep (or
	// vis) is still 0 re-traverses a subtree whenever a child has a smaller index than
	// its root, which is O(n^2) on e.g. a chain numbered from the bottom up.
	for(int r:roots)dep[r]=1,dfs(r);
	cin>>m;
	for(int i=1;i<=m;i++)cin>>x>>y,v[x].push_back(make_pair(y+dep[x],i));
	for(int r:roots)dsu(r),del(r);
	// '\n' rather than endl: avoid a flush per answer.
	for(int i=1;i<=m;i++)cout<<Ans[i]<<'\n';
}
posted @ 2022-12-06 18:22  wzx_believer  阅读(30)  评论(0)    收藏  举报