2025.10.26 闲话:单位根反演
2025.10.26 闲话-单位根反演
起因：正在和 zxk 探讨 \(k\) 叉 Bostan–Mori 算法。
jijidawang:直接单位根反演。
所以就来学习单位根反演了。
Part.1 主体
首先引入这样一个问题:
求：\(\sum_{i=0}^{n}\binom{n}{i}[2\mid i]\)
可以构造 \(f(x)=\sum_{i=0}^{n}{\binom{n}{i}x^i}\)
然后要求的是所有偶数次项的系数之和。发现如果将 \(-x\) 代入，恰好所有奇数次项的符号均被翻转，而偶数次项不变，那么有：
所求即为 \(\frac{f(1)+f(-1)}{2}\)
另一个问题，如何求：\(\sum_{i=0}^{n}\binom{n}{i}[3\mid i]\)
发现不好找一个合适的实数，使得次数不是 \(3\) 的倍数的项被消掉，但是由单位根的性质可知：
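\(w_3^0+w_3^1+w_3^2=0\)（其中 \(w_k=e^{2\pi\sqrt{-1}/k}\) 为 \(k\) 次本原单位根），这启发我们：\([3\mid i]=\frac{1}{3}\sum_{j=0}^{2}w_3^{ij}\)，因此所求即为 \(\frac{f(w_3^0)+f(w_3^1)+f(w_3^2)}{3}\)。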
那么不妨推广一下：\([k\mid n]=\frac{1}{k}\sum_{i=0}^{k-1}w_k^{in}\)
考虑怎么证明：首先发现右边的求和是一个公比为 \(w_k^n\) 的等比数列。先处理公比为 \(1\) 的情况，即 \(w_k^n=1\)，此时 \(k\mid n\)，右边等于 \(\frac{1}{k}\cdot k=1\)，代入成立。
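否则 \(w_k^n\neq 1\)，由等比数列求和公式：\(\frac{1}{k}\sum_{i=0}^{k-1}(w_k^n)^i=\frac{1}{k}\cdot\frac{(w_k^n)^k-1}{w_k^n-1}=\frac{1}{k}\cdot\frac{(w_k^k)^n-1}{w_k^n-1}=0\)，同样成立。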
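以上便是单位根反演，更常用的是放到多项式中：设 \(f(x)=\sum_{i=0}^{n}a_ix^i\)，则 \(\sum_{i=0}^{n}a_i[k\mid i]=\sum_{i=0}^{n}a_i\cdot\frac{1}{k}\sum_{j=0}^{k-1}w_k^{ij}=\frac{1}{k}\sum_{j=0}^{k-1}\sum_{i=0}^{n}a_i(w_k^j)^i\)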
注意到 \(\sum_{i=0}^{n}{a_i(w_k^j)^i}=f(w_k^j)\)
所以：\(\sum_{i=0}^{n}a_i[k\mid i]=\frac{1}{k}\sum_{j=0}^{k-1}f(w_k^j)\)
十分优美。
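在模意义下可以用原根代替单位根：若模数 \(P\) 满足 \(k\mid P-1\)，\(r\) 是模 \(P\) 的原根，则 \(r^{(P-1)/k}\) 就可以充当 \(w_k\)（下面代码中的 `qpow(3, (mod - 1) / len)` 正是如此，\(3\) 是 \(998244353\) 的原根）。不过我不会原根。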
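容易发现一件事情：如果 \(k\) 为 \(2^t\)，那么 \(f(w_k^j)\)（\(0\le j<k\)）其实就是对 \(f\) 做长度为 \(k\) 的 NTT 后得到的点值（若 \(f\) 的次数不小于 \(k\)，则相当于先把 \(x^i\) 的系数累加到 \(x^{i\bmod k}\) 上再做 NTT）。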
扩展一点,如果将 \([k|i]\) 改为 \(i\bmod k=t\) 怎么做。
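容易发现，将 \(f(x)\) 变为 \(x^{k-t}f(x)\) 后与原问题等价：因为 \(i+k-t\equiv 0\pmod k\iff i\equiv t\pmod k\)，所以 \(\sum_{i=0}^{n}a_i[i\bmod k=t]=\frac{1}{k}\sum_{j=0}^{k-1}(w_k^j)^{k-t}f(w_k^j)\)。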
Part.2 例题
luogu P10664 PYXFIB
模板题,上述过程可以推广到矩阵。
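题目要求 \(\sum_{i=0}^{\lfloor n/k\rfloor}\binom{n}{ik}F_{ik}\)，其中 \(F\) 为斐波那契数列，模数 \(P\) 满足 \(k\mid P-1\)。设 \(I\) 为单位矩阵，\(G\) 为斐波那契数列的转移矩阵，则 \(\sum_{i=0}^{n}\binom{n}{i}[k\mid i]G^i=\frac{1}{k}\sum_{j=0}^{k-1}(I+w_k^jG)^n\)，再乘上初始向量并取对应位置即为答案。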
复杂度 \(O(k\log n)\)(求原根不算在内),注意处理原根。
luogu P5591 小猪佩奇学数学
这题有一万种做法。我自己想了一种简单易懂的做法，后来看题解又学会了另一种更简单的做法，那种做法完全吊打了我的做法。
首先推式子:
\( \begin{aligned} ans &= \sum_{i=0}^n \binom n ip^{i}\left\lfloor \frac{i}{k} \right\rfloor \\ &= \sum_{i=0}^n \binom n i p^{i} \frac{i-(i\bmod k)}{k} \\ &= \sum_{i=0}^n \binom n i p^{i} \frac{i}{k} - \sum_{i=0}^n \binom n i p^{i} \frac{(i\bmod k)}{k} \\ &= \frac{1}{k}\sum_{i=0}^n \frac{n!}{i!(n-i)!} p^{i} i - \frac{1}{k}\sum_{i=0}^n \binom n i p^{i} (i\bmod k) \\ &= \frac{np}{k}\sum_{i=1}^n \frac{(n-1)!}{(i-1)!(n-i)!} p^{i-1} - \frac{1}{k}\sum_{t=0}^{k-1}\sum_{i=0}^n \binom n i p^{i} t [i\bmod k=t] \\ &= \frac{np}{k}(p+1)^{n-1} - \frac{1}{k}\sum_{t=0}^{k-1}\sum_{i=0}^n \binom n i p^{i} t [i\bmod k=t] \\ \end{aligned} \)
先讲我的做法,主要是处理后面:
设 \(f(x)=\sum_{i=0}^{n}\binom{n}{i}p^ix^{i}=(1+px)^n\)。
那么可得:
\( \begin{aligned} ans' &= \frac{1}{k}\sum_{t=0}^{k-1}\sum_{i=0}^n \binom n i p^{i} t [i\bmod k=t] \\ &= \frac{1}{k^2}\sum_{i=0}^{k-1}\sum_{t=0}^{k-1}f(w_k^i)(w_k^i)^{k-t}t \\ \end{aligned} \)
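此时需要将 \(w_k^i\) 替换为 \(x\)，需要处理：\(g(x)=\sum_{t=0}^{k-1}t\,x^{k-t}\)。由于 \(t=0\) 的项系数为 \(0\)，\(g\) 中 \(x^j\)（\(1\le j<k\)）的系数恰为 \(k-j\)，常数项为 \(0\)。于是 \(ans'=\frac{1}{k^2}\sum_{i=0}^{k-1}f(w_k^i)g(w_k^i)\)。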
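发现只需要进行两次 NTT（分别对 \(1+px\) 和 \(g(x)\)），把前者的点值逐点求 \(n\) 次幂，再与后者的点值相乘求和即可，最后：\(ans=\frac{np}{k}(p+1)^{n-1}-\frac{1}{k^2}\sum_{i=0}^{k-1}f(w_k^i)g(w_k^i)\)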
复杂度 \(O(k\log k+k\log n)\)（后一项来自对每个点值求 \(n\) 次幂）。
做法 1
#include <iostream>
#include <set>
#include <vector>
#include <algorithm>
#include <queue>
#include <cstring>
#include <unordered_map>
#include <map>
#include <ctime>
using namespace std;
// N: capacity of every global coefficient buffer; it must exceed the largest NTT length used.
// NOTE: N and mod are declared BEFORE `#define int long long`, so they are genuine 32-bit ints;
// every `int` written after the macro below is actually long long.
// mod = 998244353 is the NTT prime; Poly::NTT uses 3 as its primitive root.
const int N = 4e6 + 10, mod = 998244353;
// Widen every later `int` to 64 bits so the product of two residues below mod cannot overflow.
#define int long long
// Short aliases for common STL member names (template leftovers; unused in this listing).
#define emp emplace_back
#define pb push_back
#define fi first
#define se second
using pii = pair <int, int>;
// fac/inv back C() and A() (factorials / inverse factorials) but are never filled in this listing.
// f and g are the polynomial work buffers used by main.
// s and the global array p are unused; main's local `p` shadows the latter.
int fac[N], inv[N], s[N], p[N], f[N], g[N];
// Polynomial toolkit over Z/mod (mod = 998244353; NTT uses 3 as the primitive root).
// Coefficient arrays are 0-indexed arrays of `int` (long long via the macro above the namespace).
// Every routine works on shared global scratch buffers, so none of them is reentrant.
// NOTE(review): main only uses qpow and NTT; the remaining routines are template code.
namespace Poly
{
// Returns x^b mod `mod`. The base is reduced first, so any 64-bit x is safe to pass.
// NOTE(review): b must be >= 0 -- with the usual arithmetic right shift a negative b never reaches 0.
int qpow(int x, int b)
{
    x %= mod;
    if (x < 0) x += mod;
    int res = 1;
    while (b)
    {
        if (b & 1) res = res * x % mod;
        x = x * x % mod;
        b >>= 1;
    }
    return res;
}

int rev[N];
// In-place iterative NTT of length k (k a power of two dividing mod - 1).
// op == 0: forward transform, a[j] <- sum_i a[i] * w^(i*j) with w = 3^((mod - 1) / k).
// op == 1: inverse transform (forward pass, then reverse a[1..k) and scale by 1 / k).
// Entries of a are expected to lie in [0, mod).
void NTT(int *a, int k, bool op = 0)
{
    // Bit-reversal permutation; rev[0] stays 0, so the recurrence is well defined.
    for (int i = 0; i < k; i++) rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? k >> 1 : 0);
    for (int i = 0; i < k; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int len = 2; len <= k; len <<= 1)
    {
        int half = len >> 1;
        int wn = qpow(3, (mod - 1) / len); // primitive len-th root of unity
        for (int l = 0; l + len <= k; l += len)
        {
            int w = 1;
            for (int i = l; i < l + half; i++, w = w * wn % mod)
            {
                int x = a[i], y = a[i + half] * w % mod;
                a[i] = x + y;
                if (a[i] >= mod) a[i] -= mod;
                a[i + half] = x - y;
                if (a[i + half] < 0) a[i + half] += mod;
            }
        }
    }
    if (op)
    {
        reverse(a + 1, a + k);
        int invLen = qpow(k, mod - 2);
        for (int i = 0; i < k; i++) a[i] = a[i] * invLen % mod;
    }
}

// Set f[l..r) to v; the upper end is clamped to the buffer capacity N.
void fill(int *f, int l, int r, int v) {for (int i = l; i < min((long long)N, r); i++) f[i] = v;}
// Copy f[l..r) into h[l..r).
void copy(int *f, int *h, int l, int r) {for (int i = l; i < r; i++) h[i] = f[i];}

int mulf[N], mulg[N];
// h = f * g, where f has n coefficients and g has m coefficients. h receives the n + m - 1
// product coefficients, and h[n + m - 1 .. len) is cleared (len = smallest power of two >= n + m).
void mul(int *f, int *g, int *h, int n, int m)
{
    int len = 1;
    while (len < n + m) len <<= 1;
    fill(mulf, 0, len, 0), fill(mulg, 0, len, 0);
    // Copy only the n (resp. m) meaningful coefficients. Reading up to len would fold whatever
    // the caller left beyond the logical length into the product.
    copy(f, mulf, 0, n), copy(g, mulg, 0, m);
    NTT(mulf, len, 0), NTT(mulg, len, 0);
    for (int i = 0; i < len; i++) h[i] = mulf[i] * mulg[i] % mod;
    NTT(h, len, 1);
    for (int i = n + m - 1; i < len; i++) h[i] = 0;
}

int invh[N], invf[N];
// h = f^{-1} mod x^n via the Newton step h <- h * (2 - f * h). Requires f[0] != 0.
// h is also used as an NTT buffer of length len (< 4n); its entries from index n upward are left zero.
void Inv(int *f, int *h, int n)
{
    if (n == 1) return h[0] = qpow(f[0], mod - 2), void();
    int half = (n + 1) >> 1;
    Inv(f, h, half);
    int len = 1;
    while (len < 2 * n) len <<= 1;
    // Only h[0..half) is meaningful now; clear the rest so stale caller data cannot leak into the NTT.
    fill(h, half, len, 0);
    fill(invf, 0, len, 0);
    copy(f, invf, 0, n);
    NTT(invf, len, 0), NTT(h, len, 0);
    for (int i = 0; i < len; i++) h[i] = h[i] * (2 - h[i] * invf[i] % mod + mod) % mod;
    NTT(h, len, 1);
    fill(h, n, len, 0);
}

// In-place derivative of a polynomial with len coefficients: f[i - 1] = i * f[i]; the top slot is cleared.
void dev(int *f, int len)
{
    if (len <= 0) return; // nothing to differentiate; avoids writing f[-1]
    for (int i = 1; i < len; i++) f[i - 1] = i * f[i] % mod;
    f[len - 1] = 0;
}
// In-place integral of a polynomial with len coefficients: f[i + 1] = f[i] / (i + 1), constant term 0.
// Writes f[len], so the buffer must hold at least len + 1 entries.
void redev(int *f, int len) {for (int i = len - 1; i >= 0; i--) f[i + 1] = f[i] * qpow(i + 1, mod - 2) % mod; f[0] = 0;}

int lnf[N], lng[N];
// h = integral(f' / f) mod x^n, i.e. ln f when f[0] == 1 (the usual precondition for a formal log).
void ln(int *f, int *h, int n)
{
    fill(lnf, 0, 4 * n, 0), fill(lng, 0, 4 * n, 0);
    copy(f, lnf, 0, n);
    dev(lnf, n);
    Inv(f, lng, n);
    mul(lnf, lng, h, n, n);
    redev(h, n);
    fill(h, n, 2 * n, 0);
}

int _expf[N], expg[N];
// h = exp f mod x^n via the Newton step h <- h * (1 - ln h + f). Assumes f[0] == 0.
void exp(int *f, int *h, int n)
{
    if (n == 1) return h[0] = 1, void();
    exp(f, h, (n + 1) >> 1);
    fill(_expf, 0, 2 * n, 0);
    fill(expg, 0, 2 * n, 0);
    copy(h, _expf, 0, n);
    fill(_expf, n, 2 * n, 0);
    ln(_expf, expg, n);
    for (int i = 0; i < n; i++) expg[i] = (-expg[i] + f[i] + mod) % mod;
    expg[0]++;
    mul(_expf, expg, h, n, n);
    fill(h, n, 2 * n, 0);
}
} // namespace Poly
using namespace Poly;
// Binomial coefficient C(n, m) mod `mod`, built from the factorial table fac[] and the inverse-factorial
// table inv[]. Returns 0 when m lies outside [0, n], without indexing the tables out of range.
// NOTE(review): this program never fills fac[]/inv[], so C and A are unused template leftovers.
int C(int n, int m) {return (m < 0 || m > n) ? 0 : fac[n] * inv[m] % mod * inv[n - m] % mod;}
// Number of arrangements A(n, m) = n! / (n - m)! mod `mod`; 0 when m lies outside [0, n]
// (the range check runs before fac[m] is read).
int A(int n, int m) {return (m < 0 || m > n) ? 0 : C(n, m) * fac[m] % mod;}
// luogu P5591, approach 1: reads n, p, k and prints
//   sum_{i=0}^{n} C(n, i) p^i floor(i / k)   (mod 998244353).
// Using floor(i / k) = (i - i mod k) / k and roots-of-unity filtering:
//   ans = n p (p + 1)^{n-1} / k - (1 / k^2) * sum_{j=0}^{k-1} f(w_k^j) g(w_k^j),
// with f(x) = (1 + p x)^n and g(x) = sum_{t=1}^{k-1} t x^{k-t}.
// NOTE(review): k must be a power of two (NTT length) -- guaranteed by the problem, not checked here.
signed main()
{
    // freopen("data.in", "r", stdin); freopen("data.out", "w", stdout);
    // freopen("mission.in", "r", stdin); freopen("mission.out", "w", stdout);
    ios :: sync_with_stdio(false), cin.tie(0), cout.tie(0);
    int n, p, k; cin >> n >> p >> k;
    p %= mod; // keep every NTT input inside [0, mod) even if p is given unreduced
    // Linear factor 1 + p x. After the NTT, f[j] = 1 + p w_k^j, which is then raised to the n-th power.
    // (For k == 1 the length-1 NTT ignores f[1], but g is identically 0 then, so the sum is unaffected.)
    f[0] = 1, f[1] = p;
    // The coefficient of x^j in g is the residue t = k - j; t = 0 contributes nothing, so g[0] = 0.
    g[0] = 0;
    for (int j = 1; j < k; j++) g[j] = k - j;
    NTT(f, k, 0), NTT(g, k, 0);
    int invk = qpow(k, mod - 2);
    int dot = 0; // sum_j f(w_k^j) * g(w_k^j)
    for (int j = 0; j < k; j++) dot = (dot + qpow(f[j], n) * g[j]) % mod;
    // n p (p + 1)^{n-1} equals sum_i i C(n, i) p^i. It is 0 for n == 0, where qpow must not see exponent -1.
    int lin = n > 0 ? n % mod * p % mod * qpow(p + 1, n - 1) % mod : 0;
    int ans = (lin - invk * dot % mod + mod) % mod;
    cout << ans * invk % mod;
    return 0;
}
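更为简单的做法是注意到长度为 \(k\) 的 NTT 实际上是在模 \(x^k-1\) 意义下进行的：如果多项式的次数超过了 NTT 的长度，那么 \(x^i\) 的系数会累加到 \(x^{i\bmod k}\) 上。这和上述问题中要求的 \(c_t=\sum_{i}\binom{n}{i}p^i[i\bmod k=t]\) 恰好匹配。所以直接对 \(1+px\) 做 NTT，把点值逐点求 \(n\) 次幂后再 INTT，得到的第 \(t\) 项就是 \(c_t\)，答案为 \(\frac{1}{k}\left(np(p+1)^{n-1}-\sum_{t=0}^{k-1}t\,c_t\right)\)。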
做法 2
#include <iostream>
#include <set>
#include <vector>
#include <algorithm>
#include <queue>
#include <cstring>
#include <unordered_map>
#include <map>
#include <ctime>
using namespace std;
// N: capacity of every global coefficient buffer; it must exceed the largest NTT length used.
// NOTE: N and mod are declared BEFORE `#define int long long`, so they are genuine 32-bit ints;
// every `int` written after the macro below is actually long long.
// mod = 998244353 is the NTT prime; Poly::NTT uses 3 as its primitive root.
const int N = 4e6 + 10, mod = 998244353;
// Widen every later `int` to 64 bits so the product of two residues below mod cannot overflow.
#define int long long
// Short aliases for common STL member names (template leftovers; unused in this listing).
#define emp emplace_back
#define pb push_back
#define fi first
#define se second
using pii = pair <int, int>;
// fac/inv back C() and A() (factorials / inverse factorials) but are never filled in this listing.
// f (and, depending on the approach, g) are the polynomial work buffers used by main.
// s and the global array p are unused; main's local `p` shadows the latter.
int fac[N], inv[N], s[N], p[N], f[N], g[N];
// Polynomial toolkit over Z/mod (mod = 998244353; NTT uses 3 as the primitive root).
// Coefficient arrays are 0-indexed arrays of `int` (long long via the macro above the namespace).
// Every routine works on shared global scratch buffers, so none of them is reentrant.
// NOTE(review): main only uses qpow and NTT; the remaining routines are template code.
namespace Poly
{
// Returns x^b mod `mod`. The base is reduced first, so any 64-bit x is safe to pass.
// NOTE(review): b must be >= 0 -- with the usual arithmetic right shift a negative b never reaches 0.
int qpow(int x, int b)
{
    x %= mod;
    if (x < 0) x += mod;
    int res = 1;
    while (b)
    {
        if (b & 1) res = res * x % mod;
        x = x * x % mod;
        b >>= 1;
    }
    return res;
}

int rev[N];
// In-place iterative NTT of length k (k a power of two dividing mod - 1).
// op == 0: forward transform, a[j] <- sum_i a[i] * w^(i*j) with w = 3^((mod - 1) / k).
// op == 1: inverse transform (forward pass, then reverse a[1..k) and scale by 1 / k).
// Entries of a are expected to lie in [0, mod).
void NTT(int *a, int k, bool op = 0)
{
    // Bit-reversal permutation; rev[0] stays 0, so the recurrence is well defined.
    for (int i = 0; i < k; i++) rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? k >> 1 : 0);
    for (int i = 0; i < k; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int len = 2; len <= k; len <<= 1)
    {
        int half = len >> 1;
        int wn = qpow(3, (mod - 1) / len); // primitive len-th root of unity
        for (int l = 0; l + len <= k; l += len)
        {
            int w = 1;
            for (int i = l; i < l + half; i++, w = w * wn % mod)
            {
                int x = a[i], y = a[i + half] * w % mod;
                a[i] = x + y;
                if (a[i] >= mod) a[i] -= mod;
                a[i + half] = x - y;
                if (a[i + half] < 0) a[i + half] += mod;
            }
        }
    }
    if (op)
    {
        reverse(a + 1, a + k);
        int invLen = qpow(k, mod - 2);
        for (int i = 0; i < k; i++) a[i] = a[i] * invLen % mod;
    }
}

// Set f[l..r) to v; the upper end is clamped to the buffer capacity N.
void fill(int *f, int l, int r, int v) {for (int i = l; i < min((long long)N, r); i++) f[i] = v;}
// Copy f[l..r) into h[l..r).
void copy(int *f, int *h, int l, int r) {for (int i = l; i < r; i++) h[i] = f[i];}

int mulf[N], mulg[N];
// h = f * g, where f has n coefficients and g has m coefficients. h receives the n + m - 1
// product coefficients, and h[n + m - 1 .. len) is cleared (len = smallest power of two >= n + m).
void mul(int *f, int *g, int *h, int n, int m)
{
    int len = 1;
    while (len < n + m) len <<= 1;
    fill(mulf, 0, len, 0), fill(mulg, 0, len, 0);
    // Copy only the n (resp. m) meaningful coefficients. Reading up to len would fold whatever
    // the caller left beyond the logical length into the product.
    copy(f, mulf, 0, n), copy(g, mulg, 0, m);
    NTT(mulf, len, 0), NTT(mulg, len, 0);
    for (int i = 0; i < len; i++) h[i] = mulf[i] * mulg[i] % mod;
    NTT(h, len, 1);
    for (int i = n + m - 1; i < len; i++) h[i] = 0;
}

int invh[N], invf[N];
// h = f^{-1} mod x^n via the Newton step h <- h * (2 - f * h). Requires f[0] != 0.
// h is also used as an NTT buffer of length len (< 4n); its entries from index n upward are left zero.
void Inv(int *f, int *h, int n)
{
    if (n == 1) return h[0] = qpow(f[0], mod - 2), void();
    int half = (n + 1) >> 1;
    Inv(f, h, half);
    int len = 1;
    while (len < 2 * n) len <<= 1;
    // Only h[0..half) is meaningful now; clear the rest so stale caller data cannot leak into the NTT.
    fill(h, half, len, 0);
    fill(invf, 0, len, 0);
    copy(f, invf, 0, n);
    NTT(invf, len, 0), NTT(h, len, 0);
    for (int i = 0; i < len; i++) h[i] = h[i] * (2 - h[i] * invf[i] % mod + mod) % mod;
    NTT(h, len, 1);
    fill(h, n, len, 0);
}

// In-place derivative of a polynomial with len coefficients: f[i - 1] = i * f[i]; the top slot is cleared.
void dev(int *f, int len)
{
    if (len <= 0) return; // nothing to differentiate; avoids writing f[-1]
    for (int i = 1; i < len; i++) f[i - 1] = i * f[i] % mod;
    f[len - 1] = 0;
}
// In-place integral of a polynomial with len coefficients: f[i + 1] = f[i] / (i + 1), constant term 0.
// Writes f[len], so the buffer must hold at least len + 1 entries.
void redev(int *f, int len) {for (int i = len - 1; i >= 0; i--) f[i + 1] = f[i] * qpow(i + 1, mod - 2) % mod; f[0] = 0;}

int lnf[N], lng[N];
// h = integral(f' / f) mod x^n, i.e. ln f when f[0] == 1 (the usual precondition for a formal log).
void ln(int *f, int *h, int n)
{
    fill(lnf, 0, 4 * n, 0), fill(lng, 0, 4 * n, 0);
    copy(f, lnf, 0, n);
    dev(lnf, n);
    Inv(f, lng, n);
    mul(lnf, lng, h, n, n);
    redev(h, n);
    fill(h, n, 2 * n, 0);
}

int _expf[N], expg[N];
// h = exp f mod x^n via the Newton step h <- h * (1 - ln h + f). Assumes f[0] == 0.
void exp(int *f, int *h, int n)
{
    if (n == 1) return h[0] = 1, void();
    exp(f, h, (n + 1) >> 1);
    fill(_expf, 0, 2 * n, 0);
    fill(expg, 0, 2 * n, 0);
    copy(h, _expf, 0, n);
    fill(_expf, n, 2 * n, 0);
    ln(_expf, expg, n);
    for (int i = 0; i < n; i++) expg[i] = (-expg[i] + f[i] + mod) % mod;
    expg[0]++;
    mul(_expf, expg, h, n, n);
    fill(h, n, 2 * n, 0);
}
} // namespace Poly
using namespace Poly;
// Binomial coefficient C(n, m) mod `mod`, built from the factorial table fac[] and the inverse-factorial
// table inv[]. Returns 0 when m lies outside [0, n], without indexing the tables out of range.
// NOTE(review): this program never fills fac[]/inv[], so C and A are unused template leftovers.
int C(int n, int m) {return (m < 0 || m > n) ? 0 : fac[n] * inv[m] % mod * inv[n - m] % mod;}
// Number of arrangements A(n, m) = n! / (n - m)! mod `mod`; 0 when m lies outside [0, n]
// (the range check runs before fac[m] is read).
int A(int n, int m) {return (m < 0 || m > n) ? 0 : C(n, m) * fac[m] % mod;}
// luogu P5591, approach 2: reads n, p, k and prints
//   sum_{i=0}^{n} C(n, i) p^i floor(i / k)   (mod 998244353).
// A length-k NTT works modulo x^k - 1. So: transform 1 + p x, raise every point value to the n-th
// power, and transform back. This yields c_t = sum_{i mod k = t} C(n, i) p^i directly, and
//   ans = (n p (p + 1)^{n-1} - sum_{t=0}^{k-1} t * c_t) / k.
// NOTE(review): k must be a power of two (NTT length) -- guaranteed by the problem, not checked here.
signed main()
{
    // freopen("data.in", "r", stdin); freopen("data.out", "w", stdout);
    // freopen("mission.in", "r", stdin); freopen("mission.out", "w", stdout);
    ios :: sync_with_stdio(false), cin.tie(0), cout.tie(0);
    int n, p, k; cin >> n >> p >> k;
    p %= mod; // keep every NTT input inside [0, mod) even if p is given unreduced
    // (1 + p x) mod (x^k - 1): for k == 1 both coefficients fold into the constant term.
    if (k == 1) f[0] = (1 + p) % mod;
    else f[0] = 1, f[1] = p;
    NTT(f, k, 0);
    for (int j = 0; j < k; j++) f[j] = qpow(f[j], n);
    NTT(f, k, 1);
    // Now f[t] = c_t; residue t = 0 carries weight 0, so start at t = 1.
    int weighted = 0;
    for (int t = 1; t < k; t++) weighted = (weighted + t * f[t]) % mod;
    // n p (p + 1)^{n-1} equals sum_i i C(n, i) p^i. It is 0 for n == 0, where qpow must not see exponent -1.
    int lin = n > 0 ? n % mod * p % mod * qpow(p + 1, n - 1) % mod : 0;
    cout << (lin - weighted + mod) % mod * qpow(k, mod - 2) % mod;
    return 0;
}
速度相差不大。

浙公网安备 33010602011771号