loj #2491. 「BJOI2018」求和
#2491. 「BJOI2018」求和
题目描述
master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的 kkk 次方和,而且每次的 kkk 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。 他把这个问题交给了 pupil,但 pupil 并不会这么复杂的操作,你能帮他解决吗?
输入格式
第一行包含一个正整数 nnn,表示树的节点数。
之后 n−1n-1n−1 行每行两个空格隔开的正整数 i,ji,ji,j,表示树上的一条连接点 iii 和点 jjj 的边。
之后一行一个正整数 mmm,表示询问的数量。
之后每行三个空格隔开的正整数 i,j,ki,j,ki,j,k,表示询问从点 iii 到点 jjj 的路径上所有节点深度的 kkk 次方和。由于这个结果可能非常大,输出其对 998244353998244353998244353取模的结果。
树的节点从 111 开始标号,其中 111 号节点为树的根。
输出格式
对于每组数据输出一行一个正整数表示取模后的结果。
样例
样例输入
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
样例输出
33
503245989
样例解释
以下用 d(i)d\left(i\right)d(i) 表示第 iii 个节点的深度。
对于样例中的树,有 d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2d\left(1\right)=0,d\left(2\right)=1,d\left(3\right)=1,d\left(4\right)=2,d\left(5\right)=2d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2。
因此第一个询问答案为 (25+15+05)mod998244353=33,第二个询问答案为 (245+145+245)mod998244353=503245989。
数据范围与提示
对于30%30\%30%的数据,1≤n,m≤1001 \leq n,m \leq 1001≤n,m≤100;
对于60%60\%60%的数据,1≤n,m≤10001 \leq n,m \leq 10001≤n,m≤1000;
对于100%100\%100%的数据,1≤n,m≤300000,1≤k≤501 \leq n,m \leq 300000,1 \leq k \leq 501≤n,m≤300000,1≤k≤50。
/* 可以说是树剖裸题了,k非常小,直接预处理即可 */ #include<iostream> #include<cstdio> #include<cstring> #define mod 998244353 #define maxn 300010 using namespace std; int n,m,dep[maxn],sz[maxn],son[maxn],fa[maxn],sum[maxn*4][51],top[maxn],dfn[maxn],id,v[maxn]; int head[maxn],num; struct node{int to,pre;}e[maxn*2]; void Insert(int from,int to){ e[++num].to=to; e[num].pre=head[from]; head[from]=num; } void dfs1(int x,int father){ dep[x]=dep[father]+1; fa[x]=father; sz[x]=1; for(int i=head[x];i;i=e[i].pre){ int to=e[i].to; if(to==father)continue; dfs1(to,x); sz[x]+=sz[to]; if(sz[son[x]]<sz[to])son[x]=to; } } void dfs2(int x,int father){ top[x]=father; dfn[x]=++id;v[id]=dep[x]; if(son[x]){dfs2(son[x],father);} for(int i=head[x];i;i=e[i].pre){ int to=e[i].to; if(to==fa[x]||to==son[x])continue; dfs2(to,to); } } int Pow(int x,int y){ int res=1; while(y){ if(y&1)res=1LL*res*x%mod; x=1LL*x*x%mod; y>>=1; } return res; } void build(int k,int l,int r,int mi){ if(l==r){ sum[k][mi]=Pow(v[l],mi); return; } int mid=(l+r)>>1; build(k<<1,l,mid,mi);build(k<<1|1,mid+1,r,mi); sum[k][mi]=(sum[k<<1][mi]+sum[k<<1|1][mi])%mod; } int query(int k,int l,int r,int opl,int opr,int mi){ if(l>=opl&&r<=opr)return sum[k][mi]; int mid=(l+r)>>1,res=0; if(opl<=mid)res+=query(k<<1,l,mid,opl,opr,mi); if(opr>mid)res+=query(k<<1|1,mid+1,r,opl,opr,mi); if(res>=mod)res-=mod; return res; } int query_sum(int x,int y,int z){ int res=0; while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]])swap(x,y); res+=query(1,1,n,dfn[top[x]],dfn[x],z); if(res>=mod)res-=mod; x=fa[top[x]]; } if(dep[x]>dep[y])swap(x,y); res+=query(1,1,n,dfn[x],dfn[y],z); if(res>=mod)res-=mod; return res; } int main(){ scanf("%d",&n); int x,y,z; for(int i=1;i<n;i++){ scanf("%d%d",&x,&y); Insert(x,y);Insert(y,x); } dep[0]=-1; dfs1(1,0);dfs2(1,1); for(int i=1;i<=50;i++)build(1,1,n,i); scanf("%d",&m); for(int i=1;i<=m;i++){ scanf("%d%d%d",&x,&y,&z); printf("%d\n",query_sum(x,y,z)); } return 0; }