中国剩余定理(CRT)学习笔记

定义

中国剩余定理(Chinese Remainder Theorem,CRT) 可求解如下形式的一元线性同余方程组(其中 \(n_1,n_2,\cdots,n_k\) 两两互质):

\[\begin{cases} x \equiv a_1 \pmod {n_1} \\ x \equiv a_2 \pmod {n_2} \\ \vdots \\ x \equiv a_1 \pmod {n_k} \end{cases}\]

过程

1.计算所有模数的积 \(n\)

2.对于第 \(i\) 的方程:

  • a. 计算 \(m_i=\frac{n}{n_i}\)

  • b. 计算 \(m_i\) 在模 \(n_i\) 意义下的逆元 \(m_i^{-1}\)

  • c.计算 \(c_i=m_i m_i^{-1}\)不要对 \(n_i\) 取模)。

3.方程组在模 \(n\) 意义下的唯一解为 \(x=\sum_{i=1}^{k} a_ic_i \pmod n\)

模板题code:

#include<iostream>
using namespace std;
#define int long long
const int N=21;
int x,y,n,a[N],m[N],ans,M=1;
int mul(int a,int b,int p)
{
	int res=1;
	while(b)
	{
		if(b&1) res=(res*a)%p;
		a=(a*a)%p;
		b>>=1;
	}
	return res;
}
int exgcd(int a,int b,int &x,int &y)
{
	if(b==0)
	{
		x=1,y=0;
		return a;
    }
    int d=exgcd(b,a%b,x,y);
    int t=x;
    x=y,y=t-y*(a/b);
    return d;
}
signed main()
{
	scanf("%lld",&n);
	for(int i=1;i<=n;i++) scanf("%lld%lld",&m[i],&a[i]),M*=m[i];
	for(int i=1;i<=n;i++)
	{
		x=0,y=0;
		exgcd(M/m[i],m[i],x,y);
		if(x<0) x+=m[i];
		ans=ans+a[i]*(M/m[i])*x;
	}
	printf("%lld\n",ans%M);
	return 0;
}

扩展:模数不互质的情况

这种情况下,可以假设已经求出前 \(k-1\) 个方程组的解为 \(X\), 且有 \(M=\prod_{i=1}^{k-1} n_i\)。(为了防止溢出,一般可以取 \(M=\mathrm{lcm}_{i=1}^{j-1} n_i\),可以证明这样做与直接相乘等价)。

那么前 \(k-1\) 个方程组的通解就是 \(X+i \times M(i \in \mathbb{Z})\)

现在加入了第 \(k\) 个方程,那么就是要求出方程 \(X+t \times M \equiv a_k \pmod {n_k}\)\(t\) 的解。可以用到扩展欧几里得定理求解。

于是,前 \(k\) 个方程的通解就是 \(X_k=X+t \times M\)

模板题

#include<iostream>
using namespace std;
#define int long long
int n,M=1,x,y,a,b,ans;
int read()
{
    int f=1,x=0;
    char ss=getchar();
    while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
    while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
    return f*x;
}
int exgcd(int &x,int &y,int a,int b)
{
	if(b==0)
	{
		x=1,y=0;
		return a;
	}
	int d=exgcd(x,y,b,a%b);
	int t=x;
	x=y,y=t-a/b*y;
	return d;
}
int mul(int a,int b,int mod)
{
	int res=0;
	while(b)
	{
		if(b&1) res=(res+a)%mod;
		a=(a+a)%mod;
		b>>=1;
	}
	return res;
}
signed main()
{
	n=read();
	M=read(),ans=read();
	for(int i=2;i<=n;i++)
	{
		b=read(),a=read();
		int c=(a-ans%b+b)%b;
		int d=exgcd(x,y,M,b),k=b/d;
		x=mul(x,c/d,k);
		ans+=x*M;
		M*=k;
		ans=(ans%M+M)%M;
	}
	printf("%lld\n",(ans%M+M)%M);
	return 0;
}
posted @ 2023-03-16 21:09  曙诚  阅读(85)  评论(0)    收藏  举报