[计数] [数论] AT_abc245_h [ABC245Ex] Product Modulo 2
posted on 2025-02-26 08:35:27 | under | source
题意:序列 \(a_1\dots a_{len}\) 值域 \(a_i\in [0,m)\),求满足 \(\prod a_i \bmod m=n\) 的序列个数。\(0\le n<m\le 10^{12},k\le 10^9\)。
根据 CRT 的结论,确定 \(x\) 模 \(m\) 各个质因子下的结果,即可唯一确定 \(x\) 模 \(m\) 的结果。于是拆开来对每个质因子次幂分别计算,最后再乘起来,这样构成双射所以不重不漏。
设 \(m=p^k,n=C\times p^c\)。为了方便起见,特判掉 \(n=0\) 也就是 \(c=k\) 的情况,那么接下来 \(c<k\)。
经典套路之考虑最后一位。设 \(\prod\limits_{i\in[1,len)}a_i=A\times p^a\),而 \(a_n=B\times p^b\)。列出同余方程:
发现必须有 \(a+b=c\),否则类似于 \(Ap^a\equiv Bp^b\),容易验证不可能存在。
同时除去 \(p^c\) 仍成立,此时存在 \(A\) 的逆元,所以模 \(p^{k-c}\) 意义下 \(B\) 唯一确定,记为 \(x\)。那么有:
已知 \(B\in [0,p^{k-b})\),所以模 \(p^k\) 意义下 \(B\) 总共有 \(\frac {p^{k-b}}{p^{k-c}}=p^a\) 种取值。
现在只需统计 \(\prod\limits_{i\in [1,len)}a_i=A\times p^a\) 的方案。dp 即可,记 \(f_{i,j}\) 表示前 \(i\) 个数有 \(j\) 个 \(p\),特殊地次数 \(> k\) 也算进 \(f_{i,k}\) 里。
转移枚举当前项 \(p\) 的次数即可,注意填 \(0\) 视作 \(p^k\)。容易矩阵乘法优化。
根据先前讨论 \(ans=\sum f_{len,i}p^i\),特判 \(n=0\) 时 \(ans=f_{len,k}\)。
复杂度瓶颈在于分解质因数,\(O(\sqrt m)\)。
代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define ADD(a, b) a = (a + b) % mod
const int N = 1e6 + 5, M = 50, mod = 998244353;
int k, n, m, ans, cnt[M], pw[M];
struct Matrix{
int a[M][M];
inline Matrix () {memset(a, 0, sizeof a);}
};
inline Matrix operator * (const Matrix &A, const Matrix &B){
Matrix C;
for(int i = 0; i < M; ++i)
for(int k = 0; k < M; ++k)
for(int j = 0; j < M; ++j)
ADD(C.a[i][j], A.a[i][k] * B.a[k][j] % mod);
return C;
}
inline Matrix qstp(Matrix A, int k){
Matrix res;
for(int i = 0; i < M; ++i) res.a[i][i] = 1;
for(; k; A = A * A, k >>= 1) if(k & 1) res = res * A;
return res;
}
inline void calc(int p, int c){
pw[0] = 1;
for(int i = 1; i <= c; ++i) pw[i] = pw[i - 1] * p;
for(int i = 0; i < c; ++i) cnt[i] = (pw[c] - 1) / pw[i] - (pw[c] - 1) / pw[i + 1];
cnt[c] = 1;
Matrix A, B;
A.a[0][0] = 1;
for(int i = 0; i <= c; ++i)
for(int j = 0; j <= c; ++j)
ADD(B.a[min(c, i + j)][i], cnt[j]);
A = qstp(B, k - 1) * A;
int nc = 0, nn = n % pw[c], res = 0;
if(nn == 0) A = B * A, res = A.a[c][0] % mod;
else{
while(nn % p == 0) ++nc, nn /= p;
for(int i = 0; i <= nc; ++i)
ADD(res, A.a[i][0] * pw[i] % mod);
}
ans = ans * res % mod;
}
inline void solve(){
for(int i = 2; i * i <= m; ++i)
if(m % i == 0){
int mc = 0;
while(m % i == 0) m /= i, ++mc;
calc(i, mc);
}
if(m > 1) calc(m, 1);
}
signed main(){
cin >> k >> n >> m;
ans = 1, solve();
cout << ans;
return 0;
}

浙公网安备 33010602011771号