Karatsuba 分治乘法
算法引入
对于高精度数字的乘法,我们知道朴素的二重循环乘法
//进制:1e9
//下面所有的高精度数均为无符号数,用左开右闭区间表示,使用小端(little-endian)存储
//类似于 STL 迭代器
constexpr uint32_t BASE = 1000000000;
//二重循环乘法
//此函数将 [lbeg, lend) 和 [rbeg, rend) 表示的两个大整数相乘,并返回尾后指针
uint32_t* naive_multiply(
const uint32_t* lbeg, const uint32_t* lend,
const uint32_t* rbeg, const uint32_t* rend,
uint32_t* output) {
const size_t m1 = lend - lbeg, m2 = rend - rbeg;
std::memset(output, 0, sizeof(uint32_t) * (m1 + m2 + 1));
for (size_t i = 0; i < m1; ++i) {
uint32_t carry = 0;
for (size_t j = 0; j < m2; ++j) {
uint64_t temp = 1ULL * lbeg[i] * rbeg[j] + output[i + j] + carry;
carry = temp / BASE;
output[i + j] = temp % BASE;
}
output[i + m2] += carry;
}
uint32_t* output_end = output + m1 + m2;
if (output_end[-1] == 0)
--output_end;
return output_end;
}
如果两个数的数据大小(uint32_t 的数量)分别为 \(m_1\), \(m_2\),则算法的时间复杂度为 \(O(m_1m_2)=O(n^2)\)。但是有没有更快一点的算法呢?这就涉及到这篇文章所说的 Karatsuba 算法。
注意:
- 下文中所有 \(m_1\),\(m_2\) 指两个数的数据规模,\(n=\max(m_1,m_2)\)
算法介绍
Karatsuba 算法利用了分治的原理。首先将两个数 \(U,V\) 各分成两段,\(U_0, U_1, V_0, V_1\),满足 $ U = U_1\beta ^ k + U_0, V = V_1\beta ^ k + V_0 $,其中 $ \beta $ 是使用的进制,\(k\) 一般是任选的,但是当 $ k = \max(m_1, m_2) / 2 $ 时最好。
我们可以尝试展开一些公式:
已知加减法的时间复杂度都是 \(O(n)\),如果我们递归计算 $ U_1V_1,U_0V_0,U_1V_0,U_0V_1 $,这个算法的时间复杂度是
相比于二重循环乘法,复杂度上没有任何改进,常数还增加了不少。所以我们需要想办法减少递归乘法的次数。
Karatsuba 想了这么一个方法:计算 $ (U_0+U_1)(V_0+V_1) $。
而 $ U_0V_0 $ 和 $ U_1V_1 $ 已经被求过了,4 次递归乘法可以变成 3 次。可是乘上 \(\beta^{2k}\) 的操作如何实现?我们观察到,十进制下一个数乘 10 就是把这个数整体向左移一位,二进制下一个数乘 2 就是把这个数在二进制表示下整体向左移一位。我们可以归纳出,一个以 \(\beta\) 为进制的高精度数乘上 \(\beta^k\),就是把它整体向左移动 \(k\) 个块。对于这一步可以使用 std::memmove \(O(n)\) 地处理,也可以通过调整指针等操作 \(O(1)\) 地处理。
所以我们可以得到 Karatsuba 算法:
Karatsuba 算法
输入:两个高精度数 $ U,V $(进制是 $ \beta $,规模分别是 \(m_1\) 和 \(m_2\))。
输出:一个高精度数,满足这个数是输入两个数的乘积。
步骤:
- 令 $ k = \max(m_1, m_2) / 2$
- 将两个数 \(U,V\) 分成 \(U_0,U_1,V_0,V_1\) ,条件见上。
- 令 \(W_0=U_0V_0\)
- 令 \(W_2=U_1V_1\)
- 令 \(W_1=(U_0+U_1)(V_0+V_1)-W_0-W_2\)
- 计算 \(W=W_0+W_1\cdot \beta^k+W_2\cdot \beta^{2k}\),\(W\) 即为答案
计算 \(U_0+U_1\) ,\(V_0+V_1\) 需要辅助内存,我们可以先计算 \(W_1\),把 \(W_0\) 和 \(W_2\) 当作辅助内存,这样可以减少不必要的内存分配。
因为加减法的时间复杂度是 \(O(n)\),我们可以算出 Karatsuba 算法的时间复杂度是:
根据数据规模选择算法
Karatsuba 的时间复杂度比普通的二重循环乘法要更优一些,但是并非所有情况我们都使用这个算法。在数据规模很小的情况下,二重循环乘法要比 Karatsuba 算法更快一些。所以我们可以设定一个界限,当数据规模大于这个界限时使用 Karatsuba 算法,否则使用二重循环乘法。这样对算法效率的提升是非常显著的,尤其是对于复杂度更优的其他算法,如 Toom-Cook, FFT 等(这篇文章暂不做讨论),这些算法的常数都非常高,数据量小的情况下使用它们不会获得很好的效果。界限值与具体实现有非常大的关系。例如,Java BigInteger 的实现中,如果数据规模大于 80,则使用 Karatsuba 算法;gmplib 的实现中,如果数据规模大于 20 到 30,则使用 Karatsuba 算法;gmplib 甚至对不同的处理器设置了不同的界限。根据数据规模选择算法这个思想在实际应用中是非常重要的,它可以很好地提升速度。
在这个实现中,我将界限值 KARATSUBA_THRESHOLD 设置为 100。
一个细节
-
为什么令 \(k=\max(m_1, m_2)/2\) 而非 \(\min(m_1, m_2)/2\)?
答:如果采用后者,无法保证每次递归乘法的规模都小于等于 \(n/2\),尤其是当数据规模相差较大时。比如看这个例子:
\[U=9,123,456,789 \]\[V=5,666 \]-
当 \(k=\max(m_1,m_2)=5\)
\[U_0=56,789 \]\[U_1=91,234 \]\[V_0=5,666 \]\[V_1=0 \] -
当 \(k=\min(m_1,m_2)=2\)
\[U_0=89 \]\[U_1=91,234,567 \]\[V_0=66 \]\[V_1=56 \]
采用前者的划分方案时,每一次递归乘法的数据规模都小于等于 \(n/2\),而采用后者则不可以。
-
代码实现
洛谷上 P1919 就是高精度乘法的模板题。此题应用 FFT 解决,但是 Karatsuba + O2 optimize 勉强可以卡过,我的实现中每个点跑了 1.1s 左右。下面就是这道题的代码,对于 Karatsuba 算法更多的细节可以看代码中的注释。
#include <iostream>
#include <cstdio>
#include <cstdint>
#include <climits>
#include <algorithm>
#include <cstring>
constexpr uint32_t BASE = 1000000000;
constexpr size_t KARATSUBA_THRESHOLD = 100;
using std::min;
using std::max;
//这份代码直接使用 m, n 而不是 m1, m2 代表数据规模
//因为 m, n 比 m1, m2 在输入上更方便
//二重循环乘法
//此函数将 [lbeg, lend) 和 [rbeg, rend) 表示的两个大整数相乘,并返回尾后指针
uint32_t* naive_multiply(
const uint32_t* lbeg, const uint32_t* lend,
const uint32_t* rbeg, const uint32_t* rend,
uint32_t* output) {
const size_t m = lend - lbeg, n = rend - rbeg;
//对于两个数中有一个或者两个数为零的情况特殊判断
if (m == 0 || n == 0 || (m == 1 && *lbeg == 0) || (n == 1 && *rbeg == 0)) {
output[0] = 0;
return output + 1;
}
std::memset(output, 0, sizeof(uint32_t) * (m + n));
for (size_t i = 0; i < m; ++i) {
uint64_t carry = 0;
for (size_t j = 0; j < n; ++j) {
carry += 1ULL * lbeg[i] * rbeg[j] + output[i + j];
output[i + j] = carry % BASE;
carry /= BASE;
}
output[i + n] = carry;
}
uint32_t* output_end = output + m + n;
if (output_end[-1] == 0)
--output_end;
return output_end;
}
struct ptrpair {
uint32_t* beg;
uint32_t* end;
};
//将 [lbeg, lend), [rbeg, rend) 表示的两个大整数相加
//返回尾后指针
uint32_t* hprec_add(
const uint32_t* lbeg, const uint32_t* lend,
const uint32_t* rbeg, const uint32_t* rend,
uint32_t* output) {
if (lend - lbeg < rend - rbeg)
std::swap(lbeg, rbeg), std::swap(lend, rend);
uint32_t carry = 0;
for (; rbeg != rend; ++lbeg, ++rbeg, ++output) {
carry += *lbeg + *rbeg;
*output = carry >= BASE ? carry - BASE : carry;
carry = carry >= BASE;
}
for (; lbeg != lend; ++lbeg, ++output) {
carry += *lbeg;
*output = carry >= BASE ? carry - BASE : carry;
carry = carry >= BASE;
}
if (carry)
*output++ = carry;
return output;
}
//将 [lbeg, lend), [rbeg, rend) 表示的两个数相减,要求左面的数大于等于右面的数
//返回尾后指针
uint32_t* hprec_sub(
const uint32_t* lbeg, const uint32_t* lend,
const uint32_t* rbeg, const uint32_t* rend,
uint32_t* output) {
const size_t m = lend - lbeg, n = rend - rbeg;
const uint32_t* lmin = lbeg + min(m, n);
uint32_t carry = 0;
uint32_t* output_end = output;
for (; lbeg != lmin; ++lbeg, ++rbeg, ++output_end) {
carry = BASE + *lbeg - *rbeg - carry;
*output_end = carry >= BASE ? carry - BASE : carry;
carry = carry < BASE;
}
for (size_t i = min(m, n); i < max(m, n); ++i, ++lbeg, ++rbeg, ++output_end) {
carry = BASE + (i < m ? *lbeg : 0) - (i < n ? *rbeg : 0) - carry;
*output_end = carry >= BASE ? carry - BASE : carry;
carry = carry < BASE;
}
//去掉前导零
//如果结果为零,下面的循环得到的结果可能与预想不同,
//应先将第一个 block 设置为一个非零的数以避免此问题
carry = output[0];
output[0] = 0x3f3f3f3f;
for (; output_end[-1] == 0; --output_end);
output[0] = carry;
//上面的几条语句也可用
//for (; output_end[-1] == 0 && output_end != output + 1; --output_end)
//来代替,但是上面的语句可以减少几次判断(output_end != output + 1)的次数
return output_end;
}
//使用 Karatsuba 算法将 [lbeg, lend), [rbeg, rend) 表示的两个大整数相乘
//返回尾后指针
uint32_t* karatsuba_multiply(
const uint32_t* lbeg, const uint32_t* lend,
const uint32_t* rbeg, const uint32_t* rend,
uint32_t* output) {
size_t m = lend - lbeg, n = rend - rbeg;
//数据规模较小时,直接使用二重循环乘法
if (m < KARATSUBA_THRESHOLD || n < KARATSUBA_THRESHOLD)
return naive_multiply(lbeg, lend, rbeg, rend, output);
//对于两个数中有一个或者两个数为零的情况特殊判断
if (m == 0 || n == 0 || (m == 1 && *lbeg == 0) || (n == 1 && *rbeg == 0)) {
output[0] = 0;
return output + 1;
}
//不妨假设 m > n
if (m < n)
std::swap(lbeg, rbeg), std::swap(lend, rend), std::swap(m, n);
size_t mid = m >> 1;
const uint32_t* lmid = lbeg + mid;
const uint32_t* rmid = rbeg + mid;
//如果小的数特别小 (n < m/2),导致 rmid > rend ,就要进行调整
rmid = rmid > rend ? rend : rmid;
//W0, W1, W2 的空间
//其中 W0 可以直接使用 output 而无需单独分配内存
ptrpair W0{output}, W1{new uint32_t[m * 2 + 10]}, W2{new uint32_t[m * 2 + 10]};
std::memset(W1.beg, 0, sizeof(uint32_t) * mid);
std::memset(W2.beg, 0, sizeof(uint32_t) * mid * 2);
W1.beg += mid;
W2.beg += mid * 2;
//W0 = U0 + U1
W0.end = hprec_add(lbeg, lmid, lmid, lend, W0.beg);
//W2 = V0 + V1
W2.end = hprec_add(rbeg, rmid, rmid, rend, W2.beg);
//递归:W1 = W0 * W2 = (U0 + U1)(V0 + V1)
W1.end = karatsuba_multiply(W0.beg, W0.end, W2.beg, W2.end, W1.beg);
//递归:W0 = U0 * V0
W0.end = karatsuba_multiply(lbeg, lmid, rbeg, rmid, W0.beg);
//递归:W2 = U1 * V1
W2.end = karatsuba_multiply(lmid, lend, rmid, rend, W2.beg);
//W1 -= (W0 + W2)
W1.end = hprec_sub(W1.beg, W1.end, W0.beg, W0.end, W1.beg);
W1.end = hprec_sub(W1.beg, W1.end, W2.beg, W2.end, W1.beg);
//W1 *= beta^k
W1.beg -= mid;
//W2 *= beta^2k
W2.beg -= mid * 2;
uint32_t* output_end = nullptr;
//W = W0 + W1 + W2,W 即为答案
output_end = hprec_add(W0.beg, W0.end, W1.beg, W1.end, output);
output_end = hprec_add(output, output_end, W2.beg, W2.end, output);
delete[] W1.beg;
delete[] W2.beg; //删除辅助内存
return output_end;
}
//字符串转高精度,返回尾后指针
uint32_t* serialize(char* beg, char* end, uint32_t* output) {
uint32_t block = 0;
uint32_t pow10[10] = {};
pow10[0] = 1;
for (int i = 1; i <= 9; ++i)pow10[i] = pow10[i-1] * 10;
for (size_t i = 0; i < end - beg; ++i) {
block += pow10[i % 9] * (beg[end - beg - i - 1] - '0');
//该 block 输入完毕
if ((i+1) % 9 == 0) {
*output++ = block;
block = 0;
}
}
if (block)
*output++ = block;
return output;
}
char numstr[2000004];
uint32_t numa[111120], numb[111120], numc[222240];
int main() {
char* num1 = numstr, *num2 = numstr + 1000002;
scanf("%s", num1);
scanf("%s", num2);
ptrpair a{numa}, b{numb};
//计算
uint32_t* result = karatsuba_multiply(
numa, serialize(num1, num1 + strlen(num1), numa),
numb, serialize(num2, num2 + strlen(num2), numb),
numc
);
//高精度转字符串(逆序)
char* chptr = num1;
for (uint32_t* p = numc; p != result - 1; ++p) {
for (int i = 0; i < 9; ++i) {
*chptr++ = *p % 10 + '0';
*p /= 10;
}
}
while (result[-1]) {
*chptr++ = result[-1] % 10 + '0';
result[-1] /= 10;
}
while (chptr[-1] == '0' && chptr != num1)--chptr;
std::reverse(num1, chptr); //反转
*chptr++ = '\0';
//输出结果
puts(num1);
return 0;
}

浙公网安备 33010602011771号