P4427 [BJOI2018] 求和

P4427 [BJOI2018] 求和

题目描述

master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的 \(k\) 次方和,而且每次的 \(k\) 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了 pupil,但 pupil 并不会这么复杂的操作,你能帮他解决吗?

输入格式

第一行包含一个正整数 \(n\),表示树的节点数。

之后 \(n-1\) 行每行两个空格隔开的正整数 \(i, j\),表示树上的一条连接点 \(i\) 和点 \(j\) 的边。

之后一行一个正整数 \(m\),表示询问的数量。

之后每行三个空格隔开的正整数 \(i, j, k\),表示询问从点 \(i\) 到点 \(j\) 的路径上所有节点深度的 \(k\) 次方和。由于这个结果可能非常大,输出其对 \(998244353\) 取模的结果。

树的节点从 \(1\) 开始标号,其中 \(1\) 号节点为树的根。

数据范围

对于 \(100\%\) 的数据,\(1 \leq n,m \leq 300000\)\(1 \leq k \leq 50\)

Solution:

\(1 \leq k \leq 50\)。这令我们十分有感觉。不难想到,对于一条路径 (x,y) ,将其分成两部分,(x,lca),(lca,y)。在这两段上的点的深度是连续的。所以我们不难想到 \(O(nk)\) 打表算出 \(dep^k\) 然后求一个前缀和,就可以轻松 A 掉这道小清新绿题了 -.

Code:

#include<bits/stdc++.h>
#define int long long
const int lg=19;
const int N=3e5+5;
const int mod=998244353;
using namespace std;
int n,m;
int mul(int x,int y){return (x*y)%mod;}
int add(int x,int y){return (x+y)%mod;}
int ksm(int x,int k)
{
    if(!x)return 0;
    int res=1;
    while(k)
    {
        if(k&1)res=mul(res,x);
        x=mul(x,x);k>>=1;
    }
    return res;
}
int sum[55][N];
void init(int k)
{
    for(int i=1;i<=n;i++)
    {
        sum[k][i]=ksm(i,k);
        sum[k][i]=add(sum[k][i],sum[k][i-1]);
    }
}
vector<int> E[N];
int dep[N],f[N][lg+5];
void dfs(int x,int fa)
{
    f[x][0]=fa;
    for(int j=1;j<=lg;j++)f[x][j]=f[f[x][j-1]][j-1];
    for(int y : E[x]){if(y!=fa){dep[y]=dep[x]+1;dfs(y,x);}}
}
int LCA(int x,int y)
{
    if(dep[x]<dep[y])swap(x,y);
    for(int i=lg;i>=0;i--)if(dep[f[x][i]]>=dep[y])x=f[x][i];
    for(int i=lg;i>=0;i--)if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
    return x==y ? x : f[x][0];
}
void work()
{
    cin>>n;
    for(int i=1,x,y;i<n;i++)
    {
        scanf("%lld%lld",&x,&y);
        E[x].push_back(y);
        E[y].push_back(x);
    }
    dfs(1,0);
    cin>>m;
    for(int i=1;i<=50;i++)init(i);
    for(int i=1,x,y,k;i<=m;i++)
    {
        scanf("%lld%lld%lld",&x,&y,&k);
        int lca=LCA(x,y);
        //cout<<dep[x]<<" "<<dep[y]<<" "<<dep[lca]<<" "<<dep[f[lca][0]]<<"\n";
        //cout<<x<<" "<<y<<" "<<lca<<" "<<f[lca][0]<<"\n";
        int ans=add(add(sum[k][dep[x]],sum[k][dep[y]]),add(-sum[k][dep[lca]]-sum[k][dep[f[lca][0]]],mod<<1));
        printf("%lld\n",ans);
    }
}
#undef int
int main()
{
    //freopen("sum.in","r",stdin);freopen("sum.out","w",stdout);
    work();
    return 0;
}
posted @ 2025-01-24 16:25  liuboom  阅读(28)  评论(0)    收藏  举报