洛谷 P10084 [GDKOI2024 提高组] 计算
第一步是一个经典结论,\(L = m^{\gcd(a, b)} + 1\),\(R = m^{\gcd(c, d)}\)。
因为 \(L \equiv 1 \pmod m\) 且 \(R \equiv 0 \pmod m\),所以可以把问题的范围改成 \([1, n = R - L + 1]\)。
写出选数的生成函数:
我们希望求所有次数是 \(m\) 的倍数的项的系数之和。
施单位根反演:
考虑对于一个给定的 \(j\) 如何计算 \(F(\omega_m^j)\) 即 \(\prod\limits_{i = 1}^n (1 + (\omega_m^j)^i)\)。
设 \(d = \gcd(m, j)\)。因为有 \(\omega_m^k = \omega_m^{k \bmod m}\),又因为 \(\frac{j}{d}, \frac{2j}{d}, \ldots, \frac{nj}{d}\) 形成了 \(\frac{n}{\frac{m}{d}}\) 个模 \(\frac{m}{d}\) 的剩余系,所以:
考虑求这样一个式子:\(\prod\limits_{i = 0}^{n - 1} (1 + \omega_n^i)\)。考虑分圆多项式:
代入 \(x = -1\) 得:
也就是当 \(\frac{m}{d}\) 为偶数时,\(F(\omega_m^j) = 0\);否则 \(F(\omega_m^j) = 2^{\frac{n}{\frac{m}{d}}}\)。
把 \(F(\omega_m^j)\) 代入到答案的式子中,有:
直接计算,时间复杂度 \(O(Tm \log n)\),无法通过。
考虑一些 trivial 的优化,枚举 \(d = \gcd(m, j)\),\([0, m - 1]\) 中和 \(m\) 的 \(\gcd\) 为 \(d\) 的数的个数显然为 \(\varphi(\frac{m}{d})\),那么:
先线性筛出全部 \(\varphi(i)\) 即可计算。时间复杂度降为 \(O(T(\sqrt{m} + d(m) \log n))\),可以通过。
code
// Problem: P10084 [GDKOI2024 提高组] 计算
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P10084
// Memory Limit: 512 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 10000100;
const int N = 10000000;
const ll mod = 998244353;
ll n, m, pr[maxn / 10], tot, phi[maxn];
bool vis[maxn];
inline ll qpow(ll b, ll p) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
inline void init() {
phi[1] = 1;
for (int i = 2; i <= N; ++i) {
if (!vis[i]) {
pr[++tot] = i;
phi[i] = i - 1;
}
for (int j = 1; j <= tot && i * pr[j] <= N; ++j) {
vis[i * pr[j]] = 1;
if (i % pr[j] == 0) {
phi[i * pr[j]] = phi[i] * pr[j];
break;
}
phi[i * pr[j]] = phi[i] * (pr[j] - 1);
}
}
}
inline ll work(ll d) {
if ((m / d) % 2 == 0) {
return 0;
}
return qpow(2, n / (m / d)) * phi[m / d] % mod;
}
void solve() {
ll _a, _b, _c, _d;
scanf("%lld%lld%lld%lld%lld", &m, &_a, &_b, &_c, &_d);
ll _x = 1, _y = 1;
for (int _ = 0; _ < __gcd(_c, _d); ++_) {
_x *= m;
}
for (int _ = 0; _ < __gcd(_a, _b); ++_) {
_y *= m;
}
n = _x - _y;
if (!_a || !_b) {
n = _x;
}
ll ans = 0;
for (ll i = 1; i * i <= m; ++i) {
if (m % i) {
continue;
}
ans = (ans + work(i)) % mod;
if (i * i != m) {
ans = (ans + work(m / i)) % mod;
}
}
printf("%lld\n", ans * qpow(m, mod - 2) % mod);
}
int main() {
init();
int T = 1;
scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}