返回顶部

模板 - 扩展中国剩余定理

解k个线性同余方程构成的线性同余方程组,每个方程形如: \(x_i = c_i \space mod \space m_i\) ,假如有解输出最小非负整数解,否则输出-1。

一个更不容易溢出的版本的扩展中国剩余定理:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long long lll;

lll read() {
    lll f = 1, x = 0;
    char ch = getchar();
    while(ch < '0' || ch > '9') {
        if(ch == '-')
            f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') {
        x = (x << 3) + (x << 1) + ch - '0';
        ch = getchar();
    }
    return f * x;
}

const int MAXK = 100000 + 5;
lll ci[MAXK], mi[MAXK];

lll mul(lll a, lll b, lll mod) {
    lll res = 0;
    while(b > 0) {
        if(b & 1)
            res = (res + a) % mod;
        a = (a + a) % mod;
        b >>= 1;
    }
    return res;
}

lll exgcd(lll a, lll b, lll &x, lll &y) {
    if(b == 0) {
        x = 1, y = 0;
        return a;
    }
    lll gcd = exgcd(b, a % b, x, y);
    lll tp = x;
    x = y;
    y = tp - a / b * y;
    return gcd;
}

lll exCRT(int K) {
    lll x, y;
    lll M = mi[1], ans = ci[1];
    for(int i = 2; i <= K; i++) {
        lll a = M, b = mi[i], c = (ci[i] - ans % b + b) % b;
        lll gcd = exgcd(a, b, x, y), bg = b / gcd;
        if(c % gcd != 0)
            return -1;
        x = mul(x, c / gcd, bg);
        ans += x * M;
        M *= bg;
        ans = (ans % M + M) % M;
    }
    return ans;
}

int main() {
    int K = read();
    for(int i = 1; i <= K; ++i)
        mi[i] = read(), ci[i] = read();
    printf("%lld", (ll)exCRT(K));
    return 0;
}

下面是一个简短版本的扩展中国剩余定理,不对同余方程进行分解,但是增加了溢出的可能性。

可能在模数比较大的时候会溢出的版本,方法是直接对两个方程强行合并。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

const int MAXK = 100005;

void exgcd(ll a, ll b, ll &x, ll &y) {
    if(!b)
        x = 1, y = 0;
    else
        exgcd(b, a % b, y, x), y -= a / b * x;
}

ll inv(ll a, ll b) {
    ll x = 0, y = 0;
    exgcd(a, b, x, y);
    x = (x % b + b) % b;
    if(!x)
        x += b;
    return x;
}

int k;
ll c[MAXK], m[MAXK];

//解k个线性同余方程构成的方程组, xi = ci mod mi ,假如有解,返回最小非负整数解,否则返回-1
ll exCRT(int k) {
    ll c1, c2, m1, m2, t;
    for(int i = 2; i <= k; ++i) {
        m1 = m[i - 1], m2 = m[i], c1 = c[i - 1], c2 = c[i];
        t = __gcd(m1, m2);
        if((c2 - c1) % t != 0)
            return -1;
        m[i] = m1 * m2 / t;
        c[i] = inv(m1 / t, m2 / t) * ((c2 - c1) / t) % (m2 / t) * m1 + c1;
        c[i] = (c[i] % m[i] + m[i]) % m[i];
    }
    return c[k];
}

//解k个线性同余方程构成的方程组, xi = ci mod mi ,假如有解,返回最小非负整数解,否则返回-1
int main() {
#ifdef Inko
    freopen("Inko.in", "r", stdin);
#endif // Inko
    while(~scanf("%d", &k)) {
        for(int i = 1; i <= k; ++i)
            scanf("%lld%lld", &m[i], &c[i]);
        printf("%lld\n", exCRT(k));
    }
}

有时候题目的确会溢出,这个时候要选择带有大数取模的Java(注意BigInteger是传值而不是传引用的)。但有时候可以把每个同余方程分解成几个小的。

java版本的,使用long,效率还算可以。

package acscut;

import java.io.*;
import java.math.*;
import java.util.*;

public class Main {
    public static void main(String[] args) {
    	Scanner sc=new Scanner(System.in);
    	EXCRT excrt=new EXCRT();
    	excrt.k=sc.nextInt();
    	for(int i=1;i<=excrt.k;++i) {
    		excrt.m[i]=sc.nextLong();
    		excrt.c[i]=sc.nextLong();
    	}
    	System.out.println(excrt.exCRT());
    }
}

class EXCRT{
	long gcd(long a,long b) {
		if(b==0)
			return a;
		else
			return gcd(b,a%b);
	}
	
	void exgcd(long a,long b,long x[],long y[]) {
		if(b==0) {
			x[0]=1;
			y[0]=0;
		}
		else {
			exgcd(b,a%b,y,x);
			y[0]-=a/b*x[0];
		}
	}
	
	long inv(long a,long b) {
		long x[]=new long [1];
		long y[]=new long [1];
		x[0]=0;
		y[0]=0;
		exgcd(a,b,x,y);
		x[0]=(x[0]%b+b)%b;
		if(x[0]==0)
			x[0]+=b;
		return x[0];
	}
	
	int MAXK=100005;
	
	long c[]=new long [MAXK];
	long m[]=new long [MAXK];
	
	int k;
	
	//解k个线性同余方程构成的方程组, xi = ci mod mi ,假如有解,返回最小非负整数解,否则返回-1
    long exCRT() {
    	long c1,c2,m1,m2,t;
    	for(int i=2;i<=k;++i) {
    		m1=m[i-1];
    		m2=m[i];
    		c1=c[i-1];
    		c2=c[i];
    		t=gcd(m1,m2);
    		if((c2-c1)%t!=0) {
    			return -1;
    		}
    		m[i]=m1*m2/t;
    		c[i]=inv(m1/t,m2/t)*((c2-c1)/t)%(m2/t)*m1+c1;
    		c[i]=(c[i]%m[i]+m[i])%m[i];
    	}
    	return c[k];
    }
}

class Scanner {
    private BufferedReader br;
    private StringTokenizer st;
 
    Scanner(InputStream in) {
        br = new BufferedReader(new InputStreamReader(in));
        st = new StringTokenizer("");
    }
 
    String nextLine() {
        try {
            return br.readLine();
        } catch (IOException e) {
            throw new IOError(e);
        }
    }
 
    boolean hasNext() {
        while (!st.hasMoreTokens()) {
            String s = nextLine();
            if (s == null) {
                return false;
            }
            st = new StringTokenizer(s);
        }
        return true;
    }
 
    String next() {
        hasNext();
        return st.nextToken();
    }
 
    int nextInt() {
        return Integer.parseInt(next());
    }
 
    long nextLong() {
        return Long.parseLong(next());
    }
 
    double nextDouble() {
        return Double.parseDouble(next());
    }
}

使用BigInteger还是很有问题的,太慢了。

package acscut;

import java.io.*;
import java.math.*;
import java.util.*;

public class Main {
    public static void main(String[] args) {
    	Scanner sc=new Scanner(System.in);
    	EXCRT excrt=new EXCRT();
    	excrt.k=sc.nextInt();
    	for(int i=1;i<=excrt.k;++i) {
    		long tmp=sc.nextLong();
    		excrt.m[i]=BigInteger.valueOf(tmp);
    		tmp=sc.nextLong();
    		excrt.c[i]=BigInteger.valueOf(tmp);
    	}
    	System.out.println(excrt.exCRT());
    }
}

class EXCRT{
	BigInteger gcd(BigInteger a,BigInteger b) {
		if(b.compareTo(BigInteger.ZERO)==0)
			return a;
		else
			return gcd(b,a.mod(b));
	}
	
	void exgcd(BigInteger a,BigInteger b,BigInteger x[],BigInteger y[]) {
		if(b.compareTo(BigInteger.ZERO)==0) {
			x[0]=BigInteger.ONE;
			y[0]=BigInteger.ZERO;
		}
		else {
			exgcd(b,a.mod(b),y,x);
			y[0]=y[0].subtract(a.divide(b).multiply(x[0]));
		}
	}
	
	BigInteger inv(BigInteger a,BigInteger b) {
		BigInteger x[]=new BigInteger [1];
		BigInteger y[]=new BigInteger [1];
		x[0]=BigInteger.ZERO;
		y[0]=BigInteger.ZERO;
		exgcd(a,b,x,y);
		x[0]=(x[0].mod(b).add(b)).mod(b);
		if(x[0].compareTo(BigInteger.ZERO)==0)
			x[0]=x[0].add(b);
		return x[0];
	}
	
	int MAXK=100005;
	
	BigInteger c[]=new BigInteger [MAXK];
	BigInteger m[]=new BigInteger [MAXK];
	
	int k;
	
	//解k个线性同余方程构成的方程组, xi = ci mod mi ,假如有解,返回最小非负整数解,否则返回-1
    long exCRT() {
    	BigInteger c1,c2,m1,m2,t;
    	for(int i=2;i<=k;++i) {
    		m1=m[i-1];
    		m2=m[i];
    		c1=c[i-1];
    		c2=c[i];
    		t=gcd(m1,m2);
    		if((c2.subtract(c1)).mod(t).compareTo(BigInteger.ZERO)!=0) {
    			return -1;
    		}
    		m[i]=m1.multiply(m2).divide(t);
    		c[i]=inv(m1.divide(t),m2.divide(t)).multiply(c2.subtract(c1).divide(t).mod(m2.divide(t))).multiply(m1).add(c1);
    		c[i]=(c[i].mod(m[i]).add(m[i])).mod(m[i]);
    	}
    	return c[k].longValue();
    }
}

class Scanner {
    private BufferedReader br;
    private StringTokenizer st;
 
    Scanner(InputStream in) {
        br = new BufferedReader(new InputStreamReader(in));
        st = new StringTokenizer("");
    }
 
    String nextLine() {
        try {
            return br.readLine();
        } catch (IOException e) {
            throw new IOError(e);
        }
    }
 
    boolean hasNext() {
        while (!st.hasMoreTokens()) {
            String s = nextLine();
            if (s == null) {
                return false;
            }
            st = new StringTokenizer(s);
        }
        return true;
    }
 
    String next() {
        hasNext();
        return st.nextToken();
    }
 
    int nextInt() {
        return Integer.parseInt(next());
    }
 
    long nextLong() {
        return Long.parseLong(next());
    }
 
    double nextDouble() {
        return Double.parseDouble(next());
    }
}

最后只能够用__int128来搞了。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef __int128 lll;

const int MAXK = 100005;

void exgcd(lll a, lll b, lll &x, lll &y) {
    if(!b)
        x = 1, y = 0;
    else
        exgcd(b, a % b, y, x), y -= a / b * x;
}

lll inv(lll a, lll b) {
    lll x = 0, y = 0;
    exgcd(a, b, x, y);
    x = (x % b + b) % b;
    if(!x)
        x += b;
    return x;
}

int k;
lll c[MAXK], m[MAXK];

//解k个线性同余方程构成的方程组, xi = ci mod mi ,假如有解,返回最小非负整数解,否则返回-1
lll exCRT(int k) {
    lll c1, c2, m1, m2, t;
    for(int i = 2; i <= k; ++i) {
        m1 = m[i - 1], m2 = m[i], c1 = c[i - 1], c2 = c[i];
        t = __gcd(m1, m2);
        if((c2 - c1) % t != 0)
            return -1;
        m[i] = m1 * m2 / t;
        c[i] = inv(m1 / t, m2 / t) * ((c2 - c1) / t) % (m2 / t) * m1 + c1;
        c[i] = (c[i] % m[i] + m[i]) % m[i];
    }
    return c[k];
}

//解k个线性同余方程构成的方程组, xi = ci mod mi ,假如有解,返回最小非负整数解,否则返回-1
int main() {
#ifdef Inko
    freopen("Inko.in", "r", stdin);
#endif // Inko
    while(~scanf("%d", &k)) {
        for(int i = 1; i <= k; ++i){
            ll tmp;
            scanf("%lld",&tmp);
            m[i]=tmp;
            scanf("%lld",&tmp);
            c[i]=tmp;
        }
        printf("%lld\n", exCRT(k));
    }
}
posted @ 2019-08-20 01:26  Inko  阅读(...)  评论(...编辑  收藏