[GXOI/GZOI2019]旧词——树链剖分+线段树

题目链接:

[GXOI/GZOI2019]旧词

对于$k=1$的情况,可以参见[LNOI2014]LCA,将询问离线然后从$1$号点开始对这个点到根的路径链修改,每次询问就是对询问点到根路径链查询即可。

可以发现,如果一个点的贡献被记入答案,那么这个点到根的路径上所有点的贡献都会被记入答案。

那么对于$k>1$的情况,只要每次将路径上点$u$的权值都$+1$变成每次将路径上点$u$的权值都$+(dep[u]^k-(dep[u]-1)^k)$即可。

同样用线段树维护树剖序的区间权值和即可。

#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<vector>
#include<bitset>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const int mod=998244353;
int n,m,k;
int ans[50010];
int p[50010];
int son[50010];
int size[50010];
int f[50010];
int tot;
int head[50010];
int to[50010];
int nex[50010];
int dep[50010];
int top[50010];
int s[50010];
int q[50010];
int dfn;
int sum[400010];
int num[400010];
int tag[400010];
struct lty
{
	int x,y,id;
}a[50010];
bool cmp(lty a,lty b)
{
	return a.x<b.x;
}
int quick(int x,int y)
{
	int res=1;
	while(y)
	{
		if(y&1)
		{
			res=1ll*res*x%mod;
		}
		x=1ll*x*x%mod;
		y>>=1;
	}
	return res;
}
void add_edge(int x,int y)
{
	nex[++tot]=head[x];
	head[x]=tot;
	to[tot]=y;
}
int add(int x,int y)
{
	if(x+y<mod)
	{
		return x+y;
	}
	else
	{
		return x+y-mod;
	}
}
void dfs(int x)
{
	size[x]=1;
	for(int i=head[x];i;i=nex[i])
	{
		dep[to[i]]=dep[x]+1;
		dfs(to[i]);
		size[x]+=size[to[i]];
		if(size[to[i]]>size[son[x]])
		{
			son[x]=to[i];
		}
	}
}
void dfs2(int x,int tp)
{
	top[x]=tp;
	s[x]=++dfn;
	q[dfn]=x;
	if(son[x])
	{
		dfs2(son[x],tp);
	}
	for(int i=head[x];i;i=nex[i])
	{
		if(to[i]!=son[x])
		{
			dfs2(to[i],to[i]);
		}
	}
}
void pushup(int rt)
{
	sum[rt]=add(sum[rt<<1],sum[rt<<1|1]);
	num[rt]=add(num[rt<<1],num[rt<<1|1]);
}
void pushdown(int rt)
{
	if(tag[rt])
	{
		tag[rt<<1]=add(tag[rt],tag[rt<<1]);
		tag[rt<<1|1]=add(tag[rt],tag[rt<<1|1]);
		sum[rt<<1]=add(sum[rt<<1],1ll*tag[rt]*num[rt<<1]%mod);
		sum[rt<<1|1]=add(sum[rt<<1|1],1ll*tag[rt]*num[rt<<1|1]%mod);
		tag[rt]=0;
	}
}
void build(int rt,int l,int r)
{
	if(l==r)
	{
		num[rt]=p[dep[q[l]]];
		return ;
	}
	int mid=(l+r)>>1;
	build(rt<<1,l,mid);
	build(rt<<1|1,mid+1,r);
	pushup(rt);
}
void change(int rt,int l,int r,int L,int R)
{
	if(L<=l&&r<=R)
	{
		tag[rt]=add(tag[rt],1);
		sum[rt]=add(sum[rt],num[rt]);
		return ;
	}
	int mid=(l+r)>>1;
	pushdown(rt);
	if(L<=mid)
	{
		change(rt<<1,l,mid,L,R);
	}
	if(R>mid)
	{
		change(rt<<1|1,mid+1,r,L,R);
	}
	pushup(rt);
}
int query(int rt,int l,int r,int L,int R)
{
	if(L<=l&&r<=R)
	{
		return sum[rt];
	}
	int mid=(l+r)>>1;
	int res=0;
	pushdown(rt);
	if(L<=mid)
	{
		res=add(res,query(rt<<1,l,mid,L,R));
	}
	if(R>mid)
	{
		res=add(res,query(rt<<1|1,mid+1,r,L,R));
	}
	return res;
}
void modify(int x)
{
	while(top[x]!=1)
	{
		change(1,1,n,s[top[x]],s[x]);
		x=f[top[x]];
	}
	change(1,1,n,1,s[x]);
}
int ask(int x)
{
	int res=0;
	while(top[x]!=1)
	{
		res=add(res,query(1,1,n,s[top[x]],s[x]));
		x=f[top[x]];
	}
	res=add(res,query(1,1,n,1,s[x]));
	return res;
}
int main()
{
	scanf("%d%d%d",&n,&m,&k);
	for(int i=1;i<=n;i++)
	{
		p[i]=(quick(i,k)-quick(i-1,k)+mod)%mod;
	}
	dep[1]=1;
	for(int i=2;i<=n;i++)
	{
		scanf("%d",&f[i]);
		add_edge(f[i],i);
	}
	dfs(1);
	dfs2(1,1);
	build(1,1,n);
	for(int i=1;i<=m;i++)
	{
		scanf("%d%d",&a[i].x,&a[i].y);
		a[i].id=i;
	}
	sort(a+1,a+1+m,cmp);
	int now=0;
	for(int i=1;i<=m;i++)
	{
		while(now<a[i].x)
		{
			now++;
			modify(now);
		}
		ans[a[i].id]=ask(a[i].y);
	}
	for(int i=1;i<=m;i++)
	{
		printf("%d\n",ans[i]);
	}
}
posted @ 2019-04-17 14:21  The_Virtuoso  阅读(331)  评论(0编辑  收藏  举报