[题解]CF622F The Sum of the k-th Powers

思路

首先发现 \(\sum_{i = 1}^{n}i^k\) 是一个 \(k + 1\) 次多项式,那么我们需要求出 \(k + 2\) 个点才能得到唯一的一个 \(f(t) = \sum_{i = 1}^{t}{i^k}\)

不难通过拉格朗日插值法,将 \(x = 1 \sim (k + 2)\) 的情况一一带入:

\[f(n) = \sum_{i = 1}^{k + 2}{((\sum_{j = 1}^{i}j^k) \times (\prod_{i \neq j}{\frac{n - x_j}{x_i - x_j}}))} \]

但是,普通的拉格朗日插值法是 \(\Theta(k^2)\),于是我们需要发掘本题中的特殊性。

可以轻易将原式转化为:

\[f(n) = \sum_{i = 1}^{k + 2}{((\sum_{j = 1}^{i}j^k) \times \frac{\prod_{i \neq j}{(n - x_j)}}{\prod_{i \neq j}{(x_i - x_j)}})} \]

发现 \(x \in [1,k + 2]\),那么容易转化:

\[f(n) = \sum_{i = 1}^{k + 2}{((\sum_{j = 1}^{i}j^k) \times \frac{\prod_{i \neq j}{(n - j)}}{\prod_{i \neq j}{(i - j)}})} \]

然后你对于 \(\prod\) 里面分数的分子、分母分别计算。

对于分子:

\[\prod_{i \neq j}{(n - x_j)} = \frac{\prod_{j = 1}^{k + 2}(n - j)}{n - i} \]

然后处理出 \(\prod_{j = 1}^{k + 2}(n - j)\) 即可。

对于分母:

\[\prod_{i \neq j}{(i - j)} = (\prod_{j = 1}^{i - 1}{j}) \times (\prod_{j = -1}^{i - k - 2}{j}) \]

定义 \(g(i) = (\prod_{j = 1}^{i - 1}{j}) \times (\prod_{j = -1}^{i - k - 2}{j})\),考虑 \(g(i)\)\(g(i - 1)\) 的关系。

发现前一个 \(\prod\)\(g(i)\)\(g(i - 1)\) 多乘以一个 \(i - 1\),后一个 \(\prod\)\(g(i)\)\(g(i - 1)\) 少乘一个 \(i - k - 3\)

因此 \(g(i) = g(i - 1) \times \frac{i - 1}{i - k - 3}\)。特别的 \(g(1) = \prod_{j = -1}^{-k - 1}j\)

将分子、分母代入原式即可。

观察到当 \(k + 2 \geq n\) 时,\(n - i\) 会被减成 \(0\),因此需要暴力 \(\Theta(n)\) 计算。

Code

#include <bits/stdc++.h>
#define re register
#define int long long
#define Add(a,b) ((((a) % mod + (b) % mod) % mod + mod) % mod)
#define Mul(a,b) ((((a) % mod) * ((b) % mod) % mod + mod) % mod)
#define Div(a,b) (Mul(a,qmi(((b) % mod + mod) % mod,mod - 2)))

using namespace std;

const int mod = 1e9 + 7;
int n,k,ans;
int mul = 1,g = 1,y;

inline int read(){
    int r = 0,w = 1;
    char c = getchar();
    while (c < '0' || c > '9'){
        if (c == '-') w = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        r = (r << 3) + (r << 1) + (c ^ 48);
        c = getchar();
    }
    return r * w;
}

inline int qmi(int a,int b){
    int res = 1;
    while (b){
        if (b & 1) res = Mul(res,a);
        a = Mul(a,a),b >>= 1;
    }
    return res;
}

inline void solve1(){
    for (re int i = 1;i <= n;i++) ans = Add(ans,qmi(i,k));
}

inline void solve2(){
    for (re int i = 1;i <= k + 2;i++) mul = Mul(mul,n - i);
    for (re int i = 1;i <= k + 2;i++){
        y = Add(y,qmi(i,k));
        if (i == 1){
            for (re int j = -1;j >= -k - 1;j--) g = Mul(g,j);
        }
        else g = Mul(g,Div(i - 1,i - k - 3));
        int a = Div(mul,n - i);
        ans = Add(ans,Mul(y,Div(a,g)));
    }
}

signed main(){
    n = read(),k = read();
    if (k + 2 >= n) solve1();
    else solve2();
    printf("%lld",ans);
    return 0;
}
posted @ 2024-06-23 13:00  WBIKPS  阅读(41)  评论(0)    收藏  举报