高精度整数类 v1.7
https://github.com/headless-piston/BigInteger
(\(O(n\log n)\) 的高精度除法简直就是稀屎,可能要无限期咕咕咕了)
背景
我们痛恨高精度。
介绍
不想写高精度怎么办?提前写好模板,要用时直接复制粘贴就好了。
BigInteger
#include<vector>
#include<iostream>
#include<iomanip>
#include<algorithm>
#include<cmath>
#include<stdexcept>
namespace __FFT{
constexpr double PI2=6.283185307179586231995927;
struct complex{
double real,imag;
complex operator+(const complex &x)const{
return {real+x.real,imag+x.imag};
}
complex operator-(const complex &x)const{
return {real-x.real,imag-x.imag};
}
complex operator*(const complex &x)const{
return {real*x.real-imag*x.imag,real*x.imag+x.real*imag};
}
};
std::vector<complex> omega;
void init_omega(const int &n){
if(n<int(omega.size())) return;
int start=omega.empty()?1:omega.size()<<1;
omega.resize(1<<std::__lg(n));
for(int i=start;i<=n;i<<=1)
for(int j=0;j<(i>>1);j++){
double arg=PI2*j/i;
omega[(i>>1)+j]={cos(arg),sin(arg)};
}
}
void FFT(std::vector<complex> &a,int n,bool inv){
for(int i=0,j=0;i<n;i++){
if(i<j) std::swap(a[i],a[j]);
for(int l=n>>1;(j^=l)<l;l>>=1);
}
for(int len=2;len<=n;len<<=1)
for(int i=0;i<n;i+=len)
for(int j=0;j<(len>>1);j++){
complex w=inv?complex({omega[(len>>1)+j].real,-omega[(len>>1)+j].imag}):omega[(len>>1)+j];
complex x=a[i+j],y=a[i+j+(len>>1)]*w;
a[i+j]=x+y,a[i+j+(len>>1)]=x-y;
}
if(inv) for(int i=0;i<n;i++) a[i].real/=n;
}
}
constexpr int BASE=10000;
constexpr int DIGITS_PER_BASE=4;
class bigint{
private:
std::vector<int> num;
bool is_negative;
static int cmp_abs(const bigint &a,const bigint &b){
if(a.num.size()!=b.num.size())
return a.num.size()<b.num.size()?-1:1;
for(int i=int(a.num.size())-1;i>=0;i--)
if(a.num[i]!=b.num[i])
return a.num[i]<b.num[i]?-1:1;
return 0;
}
bigint left_shift(const int &k){
if(num.size()==1&&num[0]==0) return *this;
num.insert(num.begin(),k,0);
return *this;
}
bigint right_shift(const int &k){
if(k>=int(num.size())) return *this;
num.erase(num.begin(),num.begin()+k);
return *this;
}
std::pair<bigint,bigint> div_mod(const bigint &x)const{
if(x==0) throw std::invalid_argument("Division by zero!");
if(cmp_abs(*this,x)<0) return {0,*this};
bigint quo,rem=this->abs();
quo.num.resize(num.size()-x.num.size()+1);
for(int i=int(num.size())-int(x.num.size());i>=0;i--){
int low=0,high=BASE-1,res=0;
while(low<=high){
int mid=(low+high)/2;
bigint prod=x*mid;
prod.left_shift(i);
if(cmp_abs(prod,rem)<=0){
res=mid;
low=mid+1;
}
else high=mid-1;
}
quo.num[i]=res;
if(res!=0){
bigint prod=x.abs()*res;
prod.left_shift(i);
rem=rem-prod;
}
}
while(quo.num.size()>1&&quo.num.back()==0) quo.num.pop_back();
quo.is_negative=is_negative!=x.is_negative;
if(quo.num.size()==1&&quo.num[0]==0) quo.is_negative=false;
rem.is_negative=is_negative;
if(rem.num.size()==1&&rem.num[0]==0) rem.is_negative=false;
return {quo,rem};
}
public:
bigint():is_negative(false){num.push_back(0);}
friend std::istream &operator>>(std::istream &in,bigint &a){
std::string s;
in>>s;a=bigint(s);
return in;
}
friend std::ostream &operator<<(std::ostream &out,const bigint &a){
if(a.is_negative) out<<'-';
out<<a.num.back();
for(int i=int(a.num.size())-2;i>=0;i--)
out<<std::setw(DIGITS_PER_BASE)<<std::setfill('0')<<a.num[i];
return out;
}
bool operator<(const bigint &x)const{
if(is_negative!=x.is_negative)
return is_negative>x.is_negative;
if(num.size()!=x.num.size())
return is_negative?num.size()>x.num.size():num.size()<x.num.size();
for(int i=int(num.size())-1;i>=0;i--)
if(num[i]!=x.num[i])
return is_negative?num[i]>x.num[i]:num[i]<x.num[i];
return false;
}
bool operator>(const bigint &x)const{return x<*this;}
bool operator<=(const bigint &x)const{return !(*this>x);}
bool operator>=(const bigint &x)const{return !(*this<x);}
bool operator==(const bigint &x)const{
if(is_negative!=x.is_negative) return false;
if(num.size()!=x.num.size()) return false;
for(int i=0;i<int(num.size());i++)
if(num[i]!=x.num[i]) return false;
return true;
}
bool operator!=(const bigint &x)const{return !(*this==x);}
bigint abs()const{
bigint res=*this;res.is_negative=false;
return res;
}
bigint(long long x){
num.clear();
if(x<0) is_negative=true,x=-x;
else is_negative=false;
if(x==0) num.push_back(0);
while(x){
num.push_back(x%BASE);
x/=BASE;
}
}
bigint(const std::string &s){
if(!s.length())
throw std::invalid_argument("Error:An invalid number!");
num.clear(),is_negative=false;
int low=0;
if(s[0]=='-') low=1,is_negative=true;
int base_num=0,base_w=1;
for(int i=int(s.length())-1;i>=low;i--){
if(s[i]<'0'||s[i]>'9')
throw std::invalid_argument("Error:An invalid number!");
base_num+=(s[i]^48)*base_w;
base_w*=10;
if(base_w==BASE||i==low){
num.push_back(base_num);
base_num=0,base_w=1;
}
}
if(!num.size())
throw std::invalid_argument("Error:An invalid number!");
if(num.size()==1&&num.back()==0&&is_negative)
throw std::invalid_argument("Error:An invalid number!");
}
bigint operator-()const{
bigint res=*this;
if(res.num.size()==1&&res.num[0]==0)
res.is_negative=false;
else res.is_negative=!is_negative;
return res;
}
bigint operator+(const bigint &x)const{
bigint res;
if(is_negative==x.is_negative){
res.is_negative=is_negative;
res.num.resize(std::max(num.size(),x.num.size()));
int carry=0;
for(int i=0;i<int(res.num.size());i++){
int sum=carry;
if(i<int(num.size())) sum+=num[i];
if(i<int(x.num.size())) sum+=x.num[i];
res.num[i]=sum%BASE;
carry=sum/BASE;
}
if(carry) res.num.push_back(carry);
}
else{
int cmp=cmp_abs(*this,x);
const bigint &larger=cmp>=0?*this:x;
const bigint &smaller=cmp>=0?x:*this;
res.is_negative=cmp>=0?is_negative:x.is_negative;
res.num.resize(larger.num.size());
int borrow=0;
for(int i=0;i<int(res.num.size());i++){
int diff=larger.num[i]-borrow;
if(i<int(smaller.num.size())) diff-=smaller.num[i];
if(diff<0){
diff+=BASE;
borrow=1;
}
else borrow=0;
res.num[i]=diff;
}
while(res.num.size()>1&&res.num.back()==0) res.num.pop_back();
if(res.num.size()==1&&res.num[0]==0)
res.is_negative=0;
}
return res;
}
bigint operator-(const bigint &x)const{
bigint temp=x;
temp.is_negative=!temp.is_negative;
return *this+temp;
}
bigint operator*(const bigint &x)const{
bigint res;
res.is_negative=(is_negative!=x.is_negative);
int len=1;while(len<int(num.size()+x.num.size())) len<<=1;
std::vector<__FFT::complex> fa(len),fb(len);
for(int i=0;i<int(num.size());i++)
fa[i]={double(num[i]),0};
for(int i=0;i<int(x.num.size());i++)
fb[i]={double(x.num[i]),0};
__FFT::init_omega(len);
__FFT::FFT(fa,len,false);
__FFT::FFT(fb,len,false);
for(int i=0;i<len;i++)
fa[i]=fa[i]*fb[i];
__FFT::FFT(fa,len,true);
res.num.resize(len+1);
int carry=0;
for(int i=0;i<len;i++){
long long val=round(fa[i].real)+carry;
res.num[i]=val%BASE;
carry=val/BASE;
}
if(carry) res.num[len]=carry;
while(res.num.size()>1&&res.num.back()==0) res.num.pop_back();
if(res.num.size()==1&&res.num[0]==0)
res.is_negative=false;
return res;
}
bigint operator/(const bigint &x)const{return this->div_mod(x).first;}
bigint operator%(const bigint &x)const{return this->div_mod(x).second;}
};
使用说明
前言
这份高精度模板使用压位实现常数优化,实现过程中为了保证乘法运算的精度,最终选择了压 \(4\) 位,可以保证 \(2\times 10^{1000000}\) 以内的精度。
声明
bigint a;//动态内存
I/O 方式
为方便使用,接入了 iostream 的 I/O。
bigint a;
std::cin>>a;
std::cout<<a;
数值运算符
bigint a,b;
a+b;
a-b;
a*b;
a/b;
a%b;
除法/模运算的取整方式/符号与 C++ 标准中对整形运算的规定相同,即除法向零取整,余数的符号与被除数相同。
赋值运算符
bigint a;
a=-1919810;
关系符
除三路比较运算符(C++20)外的所有常用大小关系符。
bigint a=10,b=20;
a<b;//此表达式为true
a!=b;//此表达式为true
a==b;//此表达式为false
其他
成员函数 abs(),以 bigint 类型返回该数的绝对值。
私有函数 left_shift(const int &k) 和 right_shift(const int &k),用于在十进制下对数进行左/右移,等价于乘 \(10000^k\) 或除 \(10000^k\) 并向零取整,时间复杂度 \(O(n)\)。
Upd
v1.1:增加了 NTT 命名空间,但因为一些原因暂不使用。
v1.2:改了改码风。
v1.3:使用 C++ 标准库中的 std::complex 代替手写的复数类。修复了初始化错误的问题。
v1.4:std::complex 跑的太慢,换回手写的复数类。
v1.5:改用预处理单位根计算,提高了精度。现在可以压 \(4\) 位啦,效率大提升!
v1.6:整体重构代码,现在支持动态内存。同时大幅提升了安全性和性能,增加了非法输入的检查,高精度乘法在洛谷模板题成功进入 \(1\operatorname{s}\),现在处于最优解第 5 页。
v1.7:完善了 v1.6 的代码,修复了前导零删除的问题。由于时间和精力有限,仅实现了 \(O(n^2\log n)\) 的除法和取模。

浙公网安备 33010602011771号