/*
* $File: test.cpp
* $Date: Wed Feb 09 13:22:29 2011 +0800
* $Author: Zhou Xinyu <zxytim@gmail.com>
*
* a simple High precision integer implementation
*/
#include <cstdio>
#include <cstring>
#include <cctype>
#include <cmath>
#include <algorithm>
#include <cassert>
/*
 * Bignum -- a fixed-capacity signed big integer.
 *
 * Representation
 *   - magnitude: `nbits` limbs in `bit[]`, least significant limb first, each
 *     limb in [0, BASE).  A normalised number has no leading zero limb, and
 *     zero is exactly { nbits == 1, bit[0] == 0 }.
 *   - sign: POSITIVE or NEGATIVE; zero is always POSITIVE.
 *   - capacity: N_BIT_MAX limbs.  Every object embeds the whole limb array
 *     (about 100 KB), so prefer static or heap storage for many objects.
 *
 * Return values of the arithmetic operators
 *   operator + - * / % return a reference to a function-local static object.
 *   The reference stays valid until the same operator is called again, so a
 *   result may be chained into another operator ((a + b) + c, (a / b) / c) or
 *   copied, but two results of the SAME operator must not be alive in one
 *   expression: (a + b) * (c + d) reads the second sum twice.  The operators
 *   are therefore neither reentrant nor thread-safe.
 */
class Bignum
{
#define MULTIPLICATION_FASTER
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
private:
    typedef int Num_t;              // one limb
    typedef long long Num_bigger_t; // wide enough for limb * limb + 2 * limb

    static const int N_BIT_MAX = 25000;   // capacity in limbs
    static const int ITER_DEPTH_MAX = 14; // Karatsuba scratch levels, about log2(N_BIT_MAX)
    static const int BASE = 1000000000;   // limb radix, a power of ten

    int nbits;              // number of limbs in use (>= 1)
    Num_t bit[N_BIT_MAX];   // limbs, least significant first

    static const int POSITIVE = false,
                     NEGATIVE = true;
    // when the number is 0, @sign is %POSITIVE
    bool sign;

    bool isZero() const
    {
        return nbits == 1 && bit[0] == 0;
    }

    void setZero()
    {
        nbits = 1;
        bit[0] = 0;
        sign = POSITIVE;
    }

    // Drops leading zero limbs; the sign is left untouched.
    void trim()
    {
        while (nbits > 1 && bit[nbits - 1] == 0)
            nbits --;
    }

    // Sets the sign, forcing POSITIVE for zero so that 0 has one representation.
    void setSign(bool negative)
    {
        if (negative && !isZero())
            sign = NEGATIVE;
        else
            sign = POSITIVE;
    }

    // Three-way comparison of magnitudes: -1 if |a| < |b|, 0 if equal, 1 if |a| > |b|.
    // Both operands must be normalised.
    static int absCompare(const Bignum &a, const Bignum &b)
    {
        if (a.nbits != b.nbits)
            return a.nbits > b.nbits ? 1 : -1;
        for (int i = a.nbits - 1; i >= 0; i --)
            if (a.bit[i] != b.bit[i])
                return a.bit[i] > b.bit[i] ? 1 : -1;
        return 0;
    }

    // |a| > |b|
    static bool absGreater(const Bignum &a, const Bignum &b)
    {
        return absCompare(a, b) > 0;
    }

    // |a| >= |b|
    static bool absGreaterEqual(const Bignum &a, const Bignum &b)
    {
        return absCompare(a, b) >= 0;
    }

    /*
     * ret = |a| + |b| (magnitude only; when neither operand is zero the sign
     * of ret is left as it was, otherwise it is copied from the non-zero
     * operand -- callers set the sign afterwards).
     *
     * ret may alias a and/or b: lengths are cached up front and every limb is
     * read before the same index is written.
     */
    static void absPlus(const Bignum &a, const Bignum &b, Bignum &ret)
    {
        if (a.isZero())
        {
            ret = b;
            return;
        }
        if (b.isZero())
        {
            ret = a;
            return;
        }
        const int na = a.nbits, nb = b.nbits;
        const int len = MAX(na, nb);
        Num_t carry = 0;
        for (int i = 0; i < len; i ++)
        {
            // at most 2 * (BASE - 1) + 1, which still fits in Num_t
            Num_t sum = carry + (i < na ? a.bit[i] : 0)
                              + (i < nb ? b.bit[i] : 0);
            if (sum >= BASE)
            {
                sum -= BASE;
                carry = 1;
            } else carry = 0;
            ret.bit[i] = sum;
        }
        ret.nbits = len;
        if (carry)
        {
            assert(len < N_BIT_MAX);
            ret.bit[len] = carry;
            ret.nbits = len + 1;
        }
    }

    // Same contract as absPlus; kept because existing code calls it when the
    // destination is also an operand.  absPlus itself is alias-safe now.
    static void absPlusSafe(const Bignum &a, const Bignum &b, Bignum &c)
    {
        absPlus(a, b, c);
    }

    /*
     * ret = |a| - |b|, requires |a| >= |b| (magnitude only, the sign of ret is
     * the caller's business).  ret may alias a and/or b.
     */
    static void absMinus(const Bignum &a, const Bignum &b, Bignum &ret)
    {
#ifdef DEBUG
        assert(a.nbits >= b.nbits);
        assert(absGreaterEqual(a, b));
#endif
        if (b.isZero())
        {
            ret = a;
            return;
        }
        const int na = a.nbits, nb = b.nbits;
        Num_t borrow = 0;
        for (int i = 0; i < na; i ++)
        {
            Num_t minuend = a.bit[i] - borrow;
            const Num_t subtrahend = (i < nb ? b.bit[i] : 0);
            if (minuend < subtrahend)
            {
                minuend += BASE;
                borrow = 1;
            } else borrow = 0;
            ret.bit[i] = minuend - subtrahend;
        }
        ret.nbits = na;
        ret.trim();
    }

    /*---------------- algorithms of big integer multiplication --------------- */
    /*
     * Schoolbook multiplication, O(n^2).
     * ret = |a| * |b|; ret must NOT alias a or b.  The sign of ret is only
     * touched when the product is zero.
     */
    static void absMultiply_square_n(const Bignum &a, const Bignum &b, Bignum &ret)
    {
        if (a.isZero() || b.isZero())
        {
            ret.setZero();
            return;
        }
        const int len = a.nbits + b.nbits;
        assert(len <= N_BIT_MAX);
        for (int i = 0; i < len; i ++)
            ret.bit[i] = 0;
        for (int i = 0; i < a.nbits; i ++)
            for (int j = 0; j < b.nbits; j ++)
            {
                const int p = i + j;
                // bit[p] may hold a pending carry (< 2 * BASE); the whole sum
                // stays below BASE * BASE and therefore fits in 64 bits.
                const Num_bigger_t now = (Num_bigger_t)a.bit[i] * b.bit[j] + ret.bit[p];
                // Quotient and remainder are taken in the wide type: the old
                // `v * BASE` was evaluated in int and overflowed.
                ret.bit[p + 1] += (Num_t)(now / BASE);
                ret.bit[p] = (Num_t)(now % BASE);
            }
        ret.nbits = len;
        ret.trim();
    }

    /*
     * Karatsuba multiplication.
     *
     * Let n be the limb count of the longer operand, h = ceil(n / 2), and
     * split both operands at limb h (radix R = BASE):
     *
     *     a = A * R^h + B
     *     b = C * R^h + D
     *
     * then
     *     a * b = AC * R^(2h) + (AD + BC) * R^h + BD
     * and
     *     AD + BC = (A - B) * (D - C) + AC + BD
     *
     * so only three products -- AC, BD and (A - B)(D - C) -- are needed:
     *     T(n) = 3 T(n/2) + O(n)  =>  T(n) = O(n^(log2 3)) = O(n^1.58496)
     *
     * ret = |a| * |b|; ret must not alias a or b.  `depth` selects the scratch
     * slot used by this recursion level.
     */
    static void absMultiply_n_power_1p58496_iter(const Bignum &a, const Bignum &b, Bignum &ret, int depth)
    {
        static Bignum A[ITER_DEPTH_MAX], B[ITER_DEPTH_MAX], C[ITER_DEPTH_MAX], D[ITER_DEPTH_MAX];
        // two below are product
        static Bignum AC[ITER_DEPTH_MAX], BD[ITER_DEPTH_MAX];
        // two below are difference
        static Bignum AB[ITER_DEPTH_MAX], DC[ITER_DEPTH_MAX];
        // below is product of differences
        static Bignum ABDC[ITER_DEPTH_MAX];

        const int width = MAX(a.nbits, b.nbits);
        const int half_nbits = width - (width >> 1);
        ret.setZero();
        if (a.isZero() || b.isZero())
            return;

        // Small or very lopsided operands are faster with the schoolbook
        // method; it is also the fallback once the scratch levels run out.
        static const int NBITS_TO_FORCE = 20;
        const int min_nbits = MIN(a.nbits, b.nbits);
        if (depth >= ITER_DEPTH_MAX
                || min_nbits <= NBITS_TO_FORCE
                || min_nbits <= sqrt((double)width))
        {
            absMultiply_square_n(a, b, ret);
            return;
        }

        const Bignum *pa = &a, *pb = &b;
        if (absGreater(a, b))
            swap(pa, pb);
        // *pb is the greater one
        partition(*pa, A[depth], B[depth], width);
        partition(*pb, C[depth], D[depth], width);

        absMultiply_n_power_1p58496_iter(A[depth], C[depth], AC[depth], depth + 1);
#ifdef DEBUG
        static Bignum tmp;
        absMultiply_square_n(A[depth], C[depth], tmp);
        assert(tmp == AC[depth]);
#endif
        absMultiply_n_power_1p58496_iter(B[depth], D[depth], BD[depth], depth + 1);
#ifdef DEBUG
        absMultiply_square_n(B[depth], D[depth], tmp);
        assert(tmp == BD[depth]);
#endif
        // ret accumulates the middle term AD + BC = AC + BD + (A - B)(D - C).
        absPlus(AC[depth], BD[depth], ret);
        // The halves are non-negative (partition guarantees it), so the signed
        // subtraction yields the true sign of each difference.
        AB[depth] = A[depth] - B[depth];
        DC[depth] = D[depth] - C[depth];
        if (!(AB[depth].isZero() || DC[depth].isZero()))
        {
            absMultiply_n_power_1p58496_iter(AB[depth], DC[depth], ABDC[depth], depth + 1);
#ifdef DEBUG
            absMultiply_square_n(AB[depth], DC[depth], tmp);
            assert(tmp == ABDC[depth]);
#endif
            // AD + BC is never negative, so the subtraction cannot underflow.
            if (AB[depth].sign != DC[depth].sign)
                absMinus(ret, ABDC[depth], ret);
            else absPlusSafe(ret, ABDC[depth], ret);
        }
        left_shift_in_BASE_system(ret, half_nbits);
        left_shift_in_BASE_system(AC[depth], half_nbits << 1);
        absPlusSafe(ret, AC[depth], ret);
        absPlusSafe(ret, BD[depth], ret);
    }

    // Exchanges two operand pointers.
    static void swap(const Bignum *&a, const Bignum *&b)
    {
        const Bignum *t = a;
        a = b;
        b = t;
    }

    /*
     * Splits |n| at limb h = ceil(nbits / 2): a receives the limbs from h
     * upwards (zero if n is shorter than that), b the limbs below h.  Both
     * halves are normalised and always POSITIVE, whatever the sign of n --
     * Karatsuba subtracts them with the signed operator -.
     */
    static void partition(const Bignum &n, Bignum &a, Bignum &b, int nbits)
    {
        const int half_nbits = (nbits + 1) >> 1;
        assert(half_nbits == nbits - (nbits >> 1));
        if (n.nbits <= half_nbits)
        {
            a.setZero();
            b = n;
            b.sign = POSITIVE;
            return;
        }
        a.nbits = n.nbits - half_nbits;
        for (int i = 0; i < a.nbits; i ++)
            a.bit[i] = n.bit[half_nbits + i];
        a.trim();
        a.sign = POSITIVE;

        b.nbits = half_nbits;
        for (int i = 0; i < half_nbits; i ++)
            b.bit[i] = n.bit[i];
        b.trim();
        b.sign = POSITIVE;
    }

    // ret *= BASE^nbits (shifts the limbs up by `nbits` places); no-op for
    // zero or a non-positive shift.
    static void left_shift_in_BASE_system(Bignum &ret, int nbits)
    {
        if (nbits <= 0)
            return;
        if (ret.isZero())
            return;
        assert(ret.nbits + nbits <= N_BIT_MAX);
        for (int i = ret.nbits - 1; i >= 0; i --)
            ret.bit[i + nbits] = ret.bit[i];
        for (int i = 0; i < nbits; i ++)
            ret.bit[i] = 0;
        ret.nbits += nbits;
    }

    static void absMultiply_n_power_1p58496(const Bignum &a, const Bignum &b, Bignum &ret)
    {
        absMultiply_n_power_1p58496_iter(a, b, ret, 0);
    }

    /*
     * Placeholder for an FFT based O(n log n) multiplication.  The FFT is not
     * implemented; the call is forwarded to Karatsuba so that selecting
     * MULTIPLICATION_FASTEST still yields a correct product.
     */
    static void absMultiply_nlogn(const Bignum &a, const Bignum &b, Bignum &ret)
    {
        absMultiply_n_power_1p58496(a, b, ret);
    }

    // ret = |a| * |b| using the algorithm selected at compile time.
    // ret must not alias a or b.
    static void absMultiply(const Bignum &a, const Bignum &b, Bignum &ret)
    {
        if (a.isZero() || b.isZero())
        {
            ret.setZero();
            return;
        }
#ifdef MULTIPLICATION_SLOW
        absMultiply_square_n(a, b, ret);
#elif defined(MULTIPLICATION_FASTER)
        absMultiply_n_power_1p58496(a, b, ret);
#elif defined(MULTIPLICATION_FASTEST)
        absMultiply_nlogn(a, b, ret);
#else
        absMultiply_square_n(a, b, ret);
#endif
#ifdef DEBUG
        assert(ret.bit[ret.nbits - 1] != 0);
        assert(!ret.isZero());
#endif
    }
    /*------------ end of big integer multiplication -------------*/

    /*
     * Long division on magnitudes:
     *     quotient  = floor(|a| / |b|)
     *     remainder = |a| - quotient * |b|
     * Both results are normalised and POSITIVE.  b must not be zero and
     * quotient / remainder must not alias a or b.
     *
     * !!!IMPORTANT!!!
     * Each quotient limb is first estimated in long double arithmetic, which
     * requires |b| * BASE to be representable in a long double (about 1e4932
     * for the x87 format), i.e. |b| * BASE must not exceed roughly 4932
     * decimal digits.  The estimate is then corrected exactly, so rounding
     * can cost an extra iteration but never a wrong digit.
     */
    static void absDivMod(const Bignum &a, const Bignum &b, Bignum &quotient, Bignum &remainder)
    {
        assert(!b.isZero());
        quotient.setZero();
        remainder.setZero();
        if (absGreater(b, a))
        {
            remainder = a;
            remainder.sign = POSITIVE;
            return;
        }
        long double db = b.toLongDouble();
        if (db < 0)
            db = -db;
        static long double LONG_DOUBLE_MAX = 1e4932L;
        assert(db * BASE <= LONG_DOUBLE_MAX);

        static Bignum tmp;
        for (int i = a.nbits - 1; i >= 0; i --)
        {
            // remainder = remainder * BASE + a.bit[i]
            if (!remainder.isZero())
            {
                for (int j = remainder.nbits - 1; j >= 0; j --)
                    remainder.bit[j + 1] = remainder.bit[j];
                remainder.nbits ++;
            }
            remainder.bit[0] = a.bit[i];

            Num_t digit = 0;
            while (absGreaterEqual(remainder, b))
            {
                // remainder >= |b| here, so the true partial quotient is in
                // [1, BASE); clamp the floating point estimate to that range.
                const long double estimate = remainder.toLongDouble() / db;
                Num_t t;
                if (estimate >= BASE)
                    t = BASE - 1;
                else if (estimate < 1)
                    t = 1;
                else
                    t = (Num_t)estimate;
                absMultiply(b, t, tmp);
                // An over-estimate would make the subtraction underflow.
                while (absGreater(tmp, remainder))
                {
                    t --;
                    absMinus(tmp, b, tmp);
                }
                absMinus(remainder, tmp, remainder);
                digit += t;
#ifdef DEBUG
                assert(digit < BASE);
#endif
            }
            quotient.bit[i] = digit;
        }
        quotient.nbits = a.nbits;
        quotient.trim();
    }

public:
    // Constructs zero.
    Bignum()
    {
        setZero();
    }

    // Constructs from a machine integer (implicit on purpose: mixed
    // expressions such as `x * 3` rely on it).  LLONG_MIN is handled.
    Bignum(long long val)
    {
        unsigned long long magnitude;
        if (val < 0)
        {
            sign = NEGATIVE;
            // negate in unsigned arithmetic: -LLONG_MIN overflows as signed
            magnitude = 0ULL - (unsigned long long)val;
        }
        else
        {
            sign = POSITIVE;
            magnitude = (unsigned long long)val;
        }
        nbits = 0;
        do
        {
            bit[nbits ++] = (Num_t)(magnitude % BASE);
            magnitude /= BASE;
        } while (magnitude);
    }

    // Nearest double (may overflow to infinity for huge values).
    double toDouble() const
    {
        double ret = 0;
        for (int i = nbits - 1; i >= 0; i --)
            ret = ret * BASE + bit[i];
        if (sign == NEGATIVE)
            ret = -ret;
        return ret;
    }

    // Nearest long double (may overflow to infinity for huge values).
    long double toLongDouble() const
    {
        long double ret = 0;
        for (int i = nbits - 1; i >= 0; i --)
            ret = ret * BASE + bit[i];
        if (sign == NEGATIVE)
            ret = -ret;
        return ret;
    }

    /*
     * Parses a decimal string.  If the first character is not a digit it is
     * consumed as the sign ('-' means negative, anything else positive); any
     * other non-digit character is skipped.  Leading zeros are dropped and
     * "-0" yields plain zero; an empty string yields zero as well.
     * The number must fit in N_BIT_MAX limbs (checked with assert).
     * Returns *this.
     */
    Bignum& fromString(const char *str)
    {
        const char *begin = str;
        const char *cur = str + strlen(str); // one past the last character
        bool negative = false;
        // <cctype> functions need an unsigned char value
        if (*begin != '\0' && !isdigit((unsigned char)*begin))
        {
            negative = (*begin == '-');
            begin ++;
        }
        nbits = 1;
        bit[0] = 0;
        Num_t base = 1; // weight of the next digit inside the current limb
        while (cur > begin)
        {
            cur --;
            if (!isdigit((unsigned char)*cur))
                continue;
            if (base == BASE)
            {
                // current limb is full, open the next one
                assert(nbits < N_BIT_MAX);
                bit[nbits ++] = 0;
                base = 1;
            }
            bit[nbits - 1] += (*cur - '0') * base;
            base *= 10;
        }
        trim();
        setSign(negative);
        return *this;
    }

    // Copies only the limbs in use; self-assignment is a no-op.
    Bignum& operator = (const Bignum &n)
    {
        if (this != &n)
        {
            nbits = n.nbits;
            sign = n.sign;
            memcpy(bit, n.bit, sizeof(Num_t) * nbits);
        }
        return *this;
    }

    bool operator == (const Bignum &n) const
    {
        if (nbits != n.nbits)
            return false;
        if (sign != n.sign)
            return false;
        for (int i = 0; i < nbits; i ++)
            if (bit[i] != n.bit[i])
                return false;
        return true;
    }

    bool operator != (const Bignum &n) const
    {
        return !(*this == n);
    }

    // Signed comparison: the signs decide first, the magnitudes second.
    bool operator < (const Bignum &n) const
    {
        if (sign != n.sign)
            return sign == NEGATIVE;
        if (sign == POSITIVE)
            return absGreater(n, *this);
        return absGreater(*this, n);
    }

    bool operator > (const Bignum &n) const
    {
        return n < *this;
    }

    bool operator <= (const Bignum &n) const
    {
        if (sign != n.sign)
            return sign == NEGATIVE;
        if (sign == POSITIVE)
            return absGreaterEqual(n, *this);
        return absGreaterEqual(*this, n);
    }

    bool operator >= (const Bignum &n) const
    {
        return n <= *this;
    }

    // TODO: bit shift is currently not provided.

    // Sum.  Returns a reference to a static result (see the class comment);
    // either operand may be that very object.
    Bignum& operator + (const Bignum &n) const
    {
        static Bignum ret;
        // Cache the signs: the helpers below may overwrite an operand that
        // aliases ret.
        const bool lhs_negative = (sign == NEGATIVE);
        const bool rhs_negative = (n.sign == NEGATIVE);
        bool negative;
        if (lhs_negative != rhs_negative)
        {
            // opposite signs: subtract the smaller magnitude from the larger;
            // the larger one decides the sign
            if (absGreaterEqual(*this, n))
            {
                absMinus(*this, n, ret);
                negative = lhs_negative;
            }
            else
            {
                absMinus(n, *this, ret);
                negative = rhs_negative;
            }
        }
        else
        {
            absPlus(*this, n, ret);
            negative = lhs_negative;
        }
        ret.setSign(negative);
        return ret;
    }

    Bignum& operator += (const Bignum &n)
    {
        // TODO: don't do like below. that will take down the efficiency
        return *this = *this + n;
    }

    // Difference.  Returns a reference to a static result (see the class
    // comment); either operand may be that very object.
    Bignum& operator - (const Bignum &n) const
    {
        static Bignum ret;
        const bool lhs_negative = (sign == NEGATIVE);
        const bool rhs_negative = (n.sign == NEGATIVE);
        bool negative;
        if (lhs_negative != rhs_negative)
        {
            // a - (-b) = a + b and (-a) - b = -(a + b)
            absPlus(*this, n, ret);
            negative = lhs_negative;
        }
        else if (absGreaterEqual(*this, n))
        {
            absMinus(*this, n, ret);
            negative = lhs_negative;
        }
        else
        {
            absMinus(n, *this, ret);
            negative = !lhs_negative;
        }
        ret.setSign(negative);
        return ret;
    }

    Bignum& operator -= (const Bignum &n)
    {
        // TODO
        return *this = *this - n;
    }

    // Product.  Returns a reference to a static result (see the class
    // comment); either operand may be that very object.
    Bignum& operator * (const Bignum &n) const
    {
        static Bignum ret, product;
        if (this->isZero() || n.isZero())
        {
            ret.setZero();
            return ret;
        }
        const bool negative = (sign != n.sign);
        // Multiply into a private buffer first: the multiplication routines
        // clear their destination before reading the operands.
        absMultiply(*this, n, product);
        ret = product;
        ret.setSign(negative);
        return ret;
    }

    Bignum& operator *= (const Bignum &n)
    {
        // TODO
        return *this = *this * n;
    }

    /*
     * Quotient truncated towards zero (like the built-in integer division).
     * n must not be zero (checked with assert).
     *
     * !!!IMPORTANT!!!
     * this division algorithm works iff
     *     |n| * BASE <= max number a long double can hold,
     *     approximately 1e4932
     * that is, the number of decimal digits of the product of the divisor
     * and BASE should not exceed 4932.
     *
     * Returns a reference to a static result (see the class comment); either
     * operand may be that very object.
     */
    Bignum& operator / (const Bignum &n) const
    {
        static Bignum ret, quotient, remainder;
        const bool negative = (sign != n.sign);
        absDivMod(*this, n, quotient, remainder);
        ret = quotient;
        ret.setSign(negative);
        return ret;
    }

    Bignum& operator /= (const Bignum &n)
    {
        // TODO
        return *this = *this / n;
    }

    /*
     * Remainder matching operator /: it carries the sign of the dividend, so
     * that (a / n) * n + (a % n) == a.  Same preconditions as operator /.
     *
     * Returns a reference to a static result (see the class comment); either
     * operand may be that very object.
     */
    Bignum& operator % (const Bignum &n) const
    {
        static Bignum ret, quotient, remainder;
        const bool negative = (sign == NEGATIVE);
        absDivMod(*this, n, quotient, remainder);
        ret = remainder;
        ret.setSign(negative);
        return ret;
    }

    Bignum& operator %= (const Bignum &n)
    {
        // TODO
        return *this = *this % n;
    }

    // Writes the number in decimal to `fout`, followed by '\n' if `newline`.
    void print(bool newline = false, FILE *fout = stdout) const
    {
        if (sign == NEGATIVE)
            fprintf(fout, "-");
        // the top limb is printed without padding, all others zero-padded
        fprintf(fout, "%d", bit[nbits - 1]);
        for (int i = nbits - 2; i >= 0; i --)
            for (Num_t base = BASE / 10; base; base /= 10)
                fprintf(fout, "%d", (bit[i] / base) % 10);
        if (newline)
            fprintf(fout, "\n");
    }

    /*
     * Reads one whitespace-delimited token from `fin` and parses it with
     * fromString().  Returns false (leaving *this untouched) when no token
     * could be read.  At most one sign character plus as many digits as the
     * limb array can hold are consumed; the rest of an over-long token stays
     * in the stream.
     */
    bool scan(FILE *fin = stdin)
    {
        int digits_per_limb = 0;
        for (Num_t base = BASE; base > 1; base /= 10)
            digits_per_limb ++;
        const int max_chars = digits_per_limb * N_BIT_MAX + 1; // digits + sign
        char format[32];
        // bounded "%<width>s" so the token can never overrun the buffer
        snprintf(format, sizeof format, "%%%ds", max_chars);
        char *str = new char[max_chars + 1];
        const bool ok = (fscanf(fin, format, str) == 1);
        if (ok)
            this->fromString(str);
        delete [] str; // released on every path
        return ok;
    }
#undef MAX
#undef MIN
};
int main()
{
Bignum a, b;
while (a.scan() && b.scan())
(a * b).print(true);
return 0;
}