loj #2491. 「BJOI2018」求和

#2491. 「BJOI2018」求和

题目描述

master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的 kkk 次方和,而且每次的 kkk 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。 他把这个问题交给了 pupil,但 pupil 并不会这么复杂的操作,你能帮他解决吗?

输入格式

第一行包含一个正整数 nnn,表示树的节点数。

之后 n−1n-1n1 行每行两个空格隔开的正整数 i,ji,ji,j,表示树上的一条连接点 iii 和点 jjj 的边。

之后一行一个正整数 mmm,表示询问的数量。

之后每行三个空格隔开的正整数 i,j,ki,j,ki,j,k,表示询问从点 iii 到点 jjj 的路径上所有节点深度的 kkk 次方和。由于这个结果可能非常大,输出其对 998244353998244353998244353取模的结果。

树的节点从 111 开始标号,其中 111 号节点为树的根。

输出格式

对于每组数据输出一行一个正整数表示取模后的结果。

样例

样例输入

5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45

样例输出

33
503245989

样例解释

以下用 d(i)d\left(i\right)d(i) 表示第 iii 个节点的深度。

对于样例中的树,有 d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2d\left(1\right)=0,d\left(2\right)=1,d\left(3\right)=1,d\left(4\right)=2,d\left(5\right)=2d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2。

因此第一个询问答案为 (25+15+05)mod998244353=33,第二个询问答案为 (245+145+245)mod998244353=503245989。

数据范围与提示

对于30%30\%30%的数据,1≤n,m≤1001 \leq n,m \leq 1001n,m100;

对于60%60\%60%的数据,1≤n,m≤10001 \leq n,m \leq 10001n,m1000;

对于100%100\%100%的数据,1≤n,m≤300000,1≤k≤501 \leq n,m \leq 300000,1 \leq k \leq 501n,m300000,1k50。

 

 

/*
    可以说是树剖裸题了,k非常小,直接预处理即可
*/
#include<iostream>
#include<cstdio>
#include<cstring>
#define mod 998244353
#define maxn 300010
using namespace std;
int n,m,dep[maxn],sz[maxn],son[maxn],fa[maxn],sum[maxn*4][51],top[maxn],dfn[maxn],id,v[maxn];
int head[maxn],num;
struct node{int to,pre;}e[maxn*2];
void Insert(int from,int to){
    e[++num].to=to;
    e[num].pre=head[from];
    head[from]=num;
}
void dfs1(int x,int father){
    dep[x]=dep[father]+1;
    fa[x]=father;
    sz[x]=1;
    for(int i=head[x];i;i=e[i].pre){
        int to=e[i].to;
        if(to==father)continue;
        dfs1(to,x);
        sz[x]+=sz[to];
        if(sz[son[x]]<sz[to])son[x]=to;
    }
}
void dfs2(int x,int father){
    top[x]=father;
    dfn[x]=++id;v[id]=dep[x];
    if(son[x]){dfs2(son[x],father);}
    for(int i=head[x];i;i=e[i].pre){
        int to=e[i].to;
        if(to==fa[x]||to==son[x])continue;
        dfs2(to,to);
    }
}
int Pow(int x,int y){
    int res=1;
    while(y){
        if(y&1)res=1LL*res*x%mod;
        x=1LL*x*x%mod;
        y>>=1;
    }
    return res;
}
void build(int k,int l,int r,int mi){
    if(l==r){
        sum[k][mi]=Pow(v[l],mi);
        return;
    }
    int mid=(l+r)>>1;
    build(k<<1,l,mid,mi);build(k<<1|1,mid+1,r,mi);
    sum[k][mi]=(sum[k<<1][mi]+sum[k<<1|1][mi])%mod;
}
int query(int k,int l,int r,int opl,int opr,int mi){
    if(l>=opl&&r<=opr)return sum[k][mi];
    int mid=(l+r)>>1,res=0;
    if(opl<=mid)res+=query(k<<1,l,mid,opl,opr,mi);
    if(opr>mid)res+=query(k<<1|1,mid+1,r,opl,opr,mi);
    if(res>=mod)res-=mod;
    return res;
}
int query_sum(int x,int y,int z){
    int res=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        res+=query(1,1,n,dfn[top[x]],dfn[x],z);
        if(res>=mod)res-=mod;
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    res+=query(1,1,n,dfn[x],dfn[y],z);
    if(res>=mod)res-=mod;
    return res;
}
int main(){
    scanf("%d",&n);
    int x,y,z;
    for(int i=1;i<n;i++){
        scanf("%d%d",&x,&y);
        Insert(x,y);Insert(y,x);
    }
    dep[0]=-1;
    dfs1(1,0);dfs2(1,1);
    for(int i=1;i<=50;i++)build(1,1,n,i);
    scanf("%d",&m);
    for(int i=1;i<=m;i++){
        scanf("%d%d%d",&x,&y,&z);
        printf("%d\n",query_sum(x,y,z));
    }
    return 0;
}

 

posted @ 2018-04-15 08:24  Echo宝贝儿  阅读(502)  评论(0)    收藏  举报