单位根反演
考虑单位根有这样一个性质:
\[[k|n]={1\over k}\sum_{d=0}^{k-1}\omega_k^{dn}
\]
它有什么用呢?考虑处理一种问题,它式子里带形如 \(a_{i\bmod k}\) 的东西。我们可以枚举取模的结果,写成如下形式:
\[a_{i\bmod k} = \sum_{j=0}^{k-1}a_j [k|i-j]=\sum_{j=0}^{k-1}a_j{1\over k}\sum_{d=0}^{k-1}\omega_k^{d(i-j)}
\]
这样就消掉了下标里烦人的取模,更方便推式子,但对模数有很高要求(要有好算的原根)。
应用: P5591 小猪佩奇学数学
求: $$\sum_{i=0}^n \binom{n}{i}p^i \lfloor {i\over k} \rfloor$$
不难注意到 \(\lfloor{i\over k}\rfloor = {1\over k} (i - i\bmod k)\),我们把\(1\over k\) 提前,把式子拆成两部分:
\[{1\over k}(\sum_{i=0}^n \binom{n}{i}ip^i - \sum_{i=0}^n \binom{n}{i}p^i (i\bmod k))
\]
前一部分看着比较阳间,考虑 \(i \binom{n}{i}=n\binom{n-1}{i-1}\),则有:
\[\sum_{i=0}^n \binom{n}{i}ip^i=\sum_{i=1}^n np \binom{n-1}{i-1}p^{i-1}=np(p+1)^{n-1}
\]
考虑后面一部分有取模,用单位根反演:
\[\sum_{i=0}^n \binom{n}{i}p^i (i\bmod k)=\sum_{j=0}^{k-1}\sum_{i=0}^n \binom{n}{i}p^ij[k|i-j]
\]
\[={1\over k} \sum_{j=0}^{k-1}\sum_{i=0}^n \binom{n}{i}p^i j \sum_{d=0}^{k-1}\omega_k^{d(i-j)}
\]
推式子尽量把范围小的\(\Sigma\)放里面,只与某个变量有关的项提出去——XK
考虑后面那一坨
\[\sum_{j=0}^{k-1}\sum_{i=0}^n \binom{n}{i}p^i j \sum_{d=0}^{k-1}\omega_k^{d(i-j)}
\]
\[=\sum_{d=0}^{k-1}\sum_{j=0}^{k-1}j\omega_k^{-dj} \sum_{i=0}^n \binom{n}{i} p^i\omega_k^{di}
\]
考虑后面可以二项式定理
\[=\sum_{d=0}^{k-1}\sum_{j=0}^{k-1}j\omega_k^{-dj}(p\omega_k^d+1)^n
\]
\((p\omega_k^d+1)^n\)和\(j\)实际上没有关系,可以提前
\[=\sum_{d=0}^{k-1}(p\omega_k^d+1)^n\sum_{j=0}^{k-1}j\omega_k^{-dj}
\]
后面是差比数列求和,可以\(\Theta(\log k)\)算,复杂度 \(\Theta(k\log k)\)
#include <bits/stdc++.h>
using namespace std;
const int mod = 998244353;
inline int power(int a, int b) {
int t = 1, y = a, k = b;
while (k) {
if (k & 1) t = (1ll * t * y) % mod;
y = (1ll * y * y) % mod; k >>= 1;
} return t;
}
inline int calc(int a, int n) {
if (a == 0 ) return 0;
if (a == 1) return n;
int res = power(a, n + 1) - 1;
if (res < 0) res += mod;
res = (1ll * res * power(a - 1, mod - 2)) % mod - 1;
if (res < 0) res += mod;
return res;
}
inline int sum(int k, int n) {
if (k == 0) return 0;
if (k == 1) return (1ll * n * (n + 1) / 2) % mod;
int res = (1ll * n * power(k, n + 1)) % mod;
res -= calc(k, n); if (res < 0) res += mod;
res = (1ll * res * power(k - 1, mod - 2)) % mod;
return res;
}
int n, p, k, ans;
int main() {
cin >> n >> p >> k;
ans = (1ll * ((1ll * n * p) % mod) * power(p + 1, n - 1)) % mod;
int wk = power(3, (mod - 1) / k);
for (int d = 0; d < k; ++d) {
int res = (1ll * p * power(wk, d)) % mod;
res = power(res + 1, n);
res = (1ll * res * sum(power(power(wk, d), mod - 2), k - 1)) % mod;
ans -= (1ll * res * power(k, mod - 2)) % mod;
if (ans < 0) ans += mod;
} ans = (1ll * power(k, mod - 2) * ans) % mod;
printf("%d\n", ans);
return 0;
}