单位根反演

考虑单位根有这样一个性质:

\[[k|n]={1\over k}\sum_{d=0}^{k-1}\omega_k^{dn} \]

它有什么用呢?考虑处理一种问题,它式子里带形如 \(a_{i\bmod k}\) 的东西。我们可以枚举取模的结果,写成如下形式:

\[a_{i\bmod k} = \sum_{j=0}^{k-1}a_j [k|i-j]=\sum_{j=0}^{k-1}a_j{1\over k}\sum_{d=0}^{k-1}\omega_k^{d(i-j)} \]

这样就消掉了下标里烦人的取模,更方便推式子,但对模数有很高要求(要有好算的原根)。

应用: P5591 小猪佩奇学数学

求: $$\sum_{i=0}^n \binom{n}{i}p^i \lfloor {i\over k} \rfloor$$

不难注意到 \(\lfloor{i\over k}\rfloor = {1\over k} (i - i\bmod k)\),我们把\(1\over k\) 提前,把式子拆成两部分:

\[{1\over k}(\sum_{i=0}^n \binom{n}{i}ip^i - \sum_{i=0}^n \binom{n}{i}p^i (i\bmod k)) \]

前一部分看着比较阳间,考虑 \(i \binom{n}{i}=n\binom{n-1}{i-1}\),则有:

\[\sum_{i=0}^n \binom{n}{i}ip^i=\sum_{i=1}^n np \binom{n-1}{i-1}p^{i-1}=np(p+1)^{n-1} \]

考虑后面一部分有取模,用单位根反演:

\[\sum_{i=0}^n \binom{n}{i}p^i (i\bmod k)=\sum_{j=0}^{k-1}\sum_{i=0}^n \binom{n}{i}p^ij[k|i-j] \]

\[={1\over k} \sum_{j=0}^{k-1}\sum_{i=0}^n \binom{n}{i}p^i j \sum_{d=0}^{k-1}\omega_k^{d(i-j)} \]

推式子尽量把范围小的\(\Sigma\)放里面,只与某个变量有关的项提出去——XK

考虑后面那一坨

\[\sum_{j=0}^{k-1}\sum_{i=0}^n \binom{n}{i}p^i j \sum_{d=0}^{k-1}\omega_k^{d(i-j)} \]

\[=\sum_{d=0}^{k-1}\sum_{j=0}^{k-1}j\omega_k^{-dj} \sum_{i=0}^n \binom{n}{i} p^i\omega_k^{di} \]

考虑后面可以二项式定理

\[=\sum_{d=0}^{k-1}\sum_{j=0}^{k-1}j\omega_k^{-dj}(p\omega_k^d+1)^n \]

\((p\omega_k^d+1)^n\)\(j\)实际上没有关系,可以提前

\[=\sum_{d=0}^{k-1}(p\omega_k^d+1)^n\sum_{j=0}^{k-1}j\omega_k^{-dj} \]

后面是差比数列求和,可以\(\Theta(\log k)\)算,复杂度 \(\Theta(k\log k)\)

#include <bits/stdc++.h>

using namespace std;

const int mod = 998244353;

inline int power(int a, int b) {
	int t = 1, y = a, k = b;
	while (k) {
		if (k & 1) t = (1ll * t * y) % mod;
		y = (1ll * y * y) % mod; k >>= 1;
	} return t;
}

inline int calc(int a, int n) {
	if (a == 0 ) return 0;
	if (a == 1) return n;
	int res = power(a, n + 1) - 1;
	if (res < 0) res += mod;
	res = (1ll * res * power(a - 1, mod - 2)) % mod - 1;
	if (res < 0) res += mod;
	return res;
}

inline int sum(int k, int n) {
	if (k == 0) return 0;
	if (k == 1) return (1ll * n * (n + 1) / 2) % mod;
	int res = (1ll * n * power(k, n + 1)) % mod;
	res -= calc(k, n); if (res < 0) res += mod;
	res = (1ll * res * power(k - 1, mod - 2)) % mod;
	return res;
}

int n, p, k, ans;

int main() {
	cin >> n >> p >> k;
	ans = (1ll * ((1ll * n * p) % mod) * power(p + 1, n - 1)) % mod;
	int wk = power(3, (mod - 1) / k);
	for (int d = 0; d < k; ++d) {
		int res = (1ll * p * power(wk, d)) % mod;
		res = power(res + 1, n);
		res = (1ll * res * sum(power(power(wk, d), mod - 2), k - 1)) % mod;
		ans -= (1ll * res * power(k, mod - 2)) % mod;
		if (ans < 0) ans += mod;
	} ans = (1ll * power(k, mod - 2) * ans) % mod;
	printf("%d\n", ans);
	return 0;
} 
posted @ 2022-08-24 09:13  Smallbasic  阅读(36)  评论(0)    收藏  举报