【UR #5】怎样跑得更快
给定整数 \(c\) 和 \(d\) 和质数 \(p=998244353\)。有 \(q\) 次询问,每次询问给定长度为 \(n\) 的序列 \(b\),解方程组:
\(\forall i \in [1,n],\sum\limits_{j=1}^{n} \gcd(i,j)^c \times \operatorname{lcm}(i,j)^d \times x_j \equiv b_i \pmod p\)
你需要保证 \(0 \leq x_i < p\)。若有多组解输出任意一组,若无解输出 \(-1\)。
\(0 \leq b_i < p\),\(n \leq 10^5\),\(nq \leq 3 \times 10^5\),\(0 \leq c,d \leq 10^9\)。
为了方便,接下来使用 \(x=y\) 代替 \(x \equiv y \pmod p\)。
考虑 \(\operatorname{lcm}(i,j)=\dfrac{ij}{\gcd(i,j)}\),则原式可化为 \(\sum\limits_{j=1}^{n} \gcd(i,j)^{c-d} \times i^d \times j^d \times x_j = b_i\)。
进行反演。首先令函数 \(f(x)=x^{c-d}\),且函数 \(g(x)\) 满足 \(f(x)=\sum\limits_{p \mid x} g(p)\)。容易发现我们可以在 \(O(n \log n)\) 的时间内求出函数 \(f(x)\) 和 \(g(x)\)。此时我们可以将原式化为 \(\sum\limits_{j=1}^{n} \sum\limits_{p \mid i,p \mid j} g(p) \times i^d \times j^d \times x_j = b_i\),稍加变形可得 \(\sum\limits_{p \mid i} g(p)\sum\limits_{p \mid j} j^d \times x_j = \dfrac{b_i}{i^d}\)。
考虑 \(\sum\limits_{p \mid j} j^d \times x_j\) 的部分,发现其实际意义为所有 \(p\) 的倍数 \(j\) 对应的 \(j^d \times x_j\) 之和。不妨令其为 \(h(p)\)。此时有 \(\sum\limits_{p \mid i} g(p)h(p)= \dfrac{b_i}{i^d}\) 成立。容易发现此时我们可以再次反演,对于每个 \(p\) 求出 \(g(p)h(p)\)。这一步的复杂度也是 \(O(n \log n)\)。我们已经求出了 \(g(p)\),则我们可以对于每个 \(p\) 算出 \(h(p)=\sum\limits_{p \mid j} j^d \times x_j\)。此时我们再次反演,即可对每个 \(j\) 求出 \(j^d \times x_j\)。接下来就可以算出 \(x_j\) 了。
当然,真的需要反演吗?我们可以直接使用递推方式代替反演。
#include<iostream>
#include<cstdio>
using namespace std;
const long long mod=998244353,mod_pow=998244352;
long long pow_mod(long long num1,long long num2){
num2=(num2%mod_pow+mod_pow)%mod_pow;
long long num3=1;
while(num2){
if(num2&1) num3=num3*num1%mod;
num1=num1*num1%mod;
num2>>=1;
}
return num3;
}
int n,q;
long long c,d,b[100010],f[100010],g[100010],h[100010];
void solve(){
for(int i=1;i<=n;i++){
scanf("%lld",&b[i]);
b[i]=b[i]*pow_mod(i,-d)%mod;
}
for(int i=1;i<=n;i++){
for(int j=2;j<=n/i;j++){
b[i*j]-=b[i];
b[i*j]=(b[i*j]%mod+mod)%mod;
}
}
for(int i=1;i<=n;i++){
if(b[i]!=0 && g[i]==0){
printf("-1\n");
return ;
}
h[i]=b[i]*pow_mod(g[i],-1)%mod;
}
for(int i=n;i>=1;i--){
for(int j=2;j<=n/i;j++){
h[i]-=h[i*j];
h[i]=(h[i]%mod+mod)%mod;
}
}
for(int i=1;i<=n;i++){
printf("%lld ",h[i]*pow_mod(i,-d)%mod);
}
printf("\n");
}
int main(){
scanf("%d %lld %lld %d",&n,&c,&d,&q);
for(int i=1;i<=n;i++){
g[i]=f[i]=pow_mod(i,c-d);
}
for(int i=1;i<=n;i++){
for(int j=2;j<=n/i;j++){
g[i*j]-=g[i];
g[i*j]=(g[i*j]%mod+mod)%mod;
}
}
while(q--){
solve();
}
return 0;
}

浙公网安备 33010602011771号