中国剩余定理(CRT)学习笔记
定义
中国剩余定理(Chinese Remainder Theorem,CRT) 可求解如下形式的一元线性同余方程组(其中 \(n_1,n_2,\cdots,n_k\) 两两互质):
\[\begin{cases}
x \equiv a_1 \pmod {n_1} \\
x \equiv a_2 \pmod {n_2} \\
\vdots \\
x \equiv a_1 \pmod {n_k}
\end{cases}\]
过程
1.计算所有模数的积 \(n\);
2.对于第 \(i\) 的方程:
-
a. 计算 \(m_i=\frac{n}{n_i}\);
-
b. 计算 \(m_i\) 在模 \(n_i\) 意义下的逆元 \(m_i^{-1}\);
-
c.计算 \(c_i=m_i m_i^{-1}\) (不要对 \(n_i\) 取模)。
3.方程组在模 \(n\) 意义下的唯一解为 \(x=\sum_{i=1}^{k} a_ic_i \pmod n\)
模板题code:
#include<iostream>
using namespace std;
#define int long long
const int N=21;
int x,y,n,a[N],m[N],ans,M=1;
int mul(int a,int b,int p)
{
int res=1;
while(b)
{
if(b&1) res=(res*a)%p;
a=(a*a)%p;
b>>=1;
}
return res;
}
int exgcd(int a,int b,int &x,int &y)
{
if(b==0)
{
x=1,y=0;
return a;
}
int d=exgcd(b,a%b,x,y);
int t=x;
x=y,y=t-y*(a/b);
return d;
}
signed main()
{
scanf("%lld",&n);
for(int i=1;i<=n;i++) scanf("%lld%lld",&m[i],&a[i]),M*=m[i];
for(int i=1;i<=n;i++)
{
x=0,y=0;
exgcd(M/m[i],m[i],x,y);
if(x<0) x+=m[i];
ans=ans+a[i]*(M/m[i])*x;
}
printf("%lld\n",ans%M);
return 0;
}
扩展:模数不互质的情况
这种情况下,可以假设已经求出前 \(k-1\) 个方程组的解为 \(X\), 且有 \(M=\prod_{i=1}^{k-1} n_i\)。(为了防止溢出,一般可以取 \(M=\mathrm{lcm}_{i=1}^{j-1} n_i\),可以证明这样做与直接相乘等价)。
那么前 \(k-1\) 个方程组的通解就是 \(X+i \times M(i \in \mathbb{Z})\)。
现在加入了第 \(k\) 个方程,那么就是要求出方程 \(X+t \times M \equiv a_k \pmod {n_k}\) 中 \(t\) 的解。可以用到扩展欧几里得定理求解。
于是,前 \(k\) 个方程的通解就是 \(X_k=X+t \times M\)。
模板题:
#include<iostream>
using namespace std;
#define int long long
int n,M=1,x,y,a,b,ans;
int read()
{
int f=1,x=0;
char ss=getchar();
while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
return f*x;
}
int exgcd(int &x,int &y,int a,int b)
{
if(b==0)
{
x=1,y=0;
return a;
}
int d=exgcd(x,y,b,a%b);
int t=x;
x=y,y=t-a/b*y;
return d;
}
int mul(int a,int b,int mod)
{
int res=0;
while(b)
{
if(b&1) res=(res+a)%mod;
a=(a+a)%mod;
b>>=1;
}
return res;
}
signed main()
{
n=read();
M=read(),ans=read();
for(int i=2;i<=n;i++)
{
b=read(),a=read();
int c=(a-ans%b+b)%b;
int d=exgcd(x,y,M,b),k=b/d;
x=mul(x,c/d,k);
ans+=x*M;
M*=k;
ans=(ans%M+M)%M;
}
printf("%lld\n",(ans%M+M)%M);
return 0;
}