【YbtOJ#20056】树上取模

题目

题目链接:http://noip.ybtoj.com.cn/problem/20056
给出一棵树,点有点权,要求支持一个子树内所有树取模 \(k\),单点修改,求一条链的和。
\(n\leq 10^5,1\leq a_i,k\leq 10^8\)

思路

没有操作一就是裸的树剖。
发现一个数 \(x\) 取模 \(k\) 之后只有两种情况:

  • \(x<k\) 时,\(x\bmod k\) 依然为 \(x\)
  • \(x\geq k\) 时,\(x\bmod k\) 一定 \(\leq \frac{x}{2}\)
    所以每一个数在不修改的前提下,只会被取模 \(\log\) 次。那么线段树维护区间最大值即可。如果一个区间最大值不小于 \(k\) 那么就往这个区间走。
    时间复杂度 \(O(m\log n(\log n+\log A))\),其中 \(A=\max(a_i)\)

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=100010,LG=20;
int n,Q,tot,a[N],head[N],fa[N][LG+1],dep[N],son[N],id[N],rk[N],size[N],top[N];

struct edge
{
	int next,to;
}e[N*2];

void add(int from,int to)
{
	e[++tot].to=to;
	e[tot].next=head[from];
	head[from]=tot;
}

void dfs1(int x,int f)
{
	fa[x][0]=f; size[x]=1; dep[x]=dep[f]+1;
	for (int i=1;i<=LG;i++)
		fa[x][i]=fa[fa[x][i-1]][i-1];
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa[x][0])
		{
			dfs1(v,x);
			size[x]+=size[v];
			if (size[v]>size[son[x]]) son[x]=v;
		}
	}
}

void dfs2(int x,int tp)
{
	top[x]=tp; id[x]=++tot; rk[tot]=x;
	if (son[x]) dfs2(son[x],tp);
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa[x][0] && v!=son[x]) dfs2(v,v);
	}
}

int lca(int x,int y)
{
	if (dep[x]<dep[y]) swap(x,y);
	for (int i=LG;i>=0;i--)
		if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];
	if (x==y) return x;
	for (int i=LG;i>=0;i--)
		if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}

struct SegTree
{
	int l[N*4],r[N*4],maxn[N*4];
	ll sum[N*4];
	
	void pushup(int x)
	{
		sum[x]=sum[x*2]+sum[x*2+1];
		maxn[x]=max(maxn[x*2],maxn[x*2+1]);
	}
	
	void build(int x,int ql,int qr)
	{
		l[x]=ql; r[x]=qr;
		if (ql==qr)
		{
			sum[x]=maxn[x]=a[rk[ql]];
			return;
		}
		int mid=(ql+qr)>>1;
		build(x*2,ql,mid); build(x*2+1,mid+1,qr);
		pushup(x);
	}
	
	void updmod(int x,int ql,int qr,int p)
	{
		if (l[x]==r[x])
		{
			sum[x]%=p; maxn[x]%=p;
			return;
		}
		int mid=(l[x]+r[x])>>1;
		if (maxn[x*2]>=p && ql<=mid) updmod(x*2,ql,qr,p);
		if (maxn[x*2+1]>=p && qr>mid) updmod(x*2+1,ql,qr,p);
		pushup(x);
	}
	
	void update(int x,int k,int v)
	{
		if (l[x]==k && r[x]==k)
		{
			sum[x]=maxn[x]=v;
			return;
		}
		int mid=(l[x]+r[x])>>1;
		if (k<=mid) update(x*2,k,v);
			else update(x*2+1,k,v);
		pushup(x);
	}
	
	ll query(int x,int ql,int qr)
	{
		if (l[x]==ql && r[x]==qr)
			return sum[x];
		int mid=(l[x]+r[x])>>1;
		if (qr<=mid) return query(x*2,ql,qr);
		if (ql>mid) return query(x*2+1,ql,qr);
		return query(x*2,ql,mid)+query(x*2+1,mid+1,qr);
	}
}seg;

ll query(int y,int x)
{
	ll ans=0;
	while (dep[top[x]]>dep[y])
	{
		ans+=seg.query(1,id[top[x]],id[x]);
		x=fa[top[x]][0];
	}
	ans+=seg.query(1,id[y],id[x]);
	return ans;
}

int main()
{
	freopen("flower.in","r",stdin);
	freopen("flower.out","w",stdout);
	memset(head,-1,sizeof(head));
	scanf("%d%d",&n,&Q);
	for (int i=1,x,y;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	for (int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	tot=0;
	dfs1(1,0); dfs2(1,1);
	seg.build(1,1,n);
	int type,x,y;
	while (Q--)
	{
		scanf("%d%d%d",&type,&x,&y);
		if (type==1) seg.updmod(1,id[x],id[x]+size[x]-1,y);
		if (type==2) seg.update(1,id[x],y);
		if (type==3)
		{
			int p=lca(x,y);
			printf("%lld\n",query(p,x)+query(p,y)-query(p,p));
		}
	}
	return 0;
}
posted @ 2020-09-15 15:41  stoorz  阅读(160)  评论(0编辑  收藏  举报