Karatsuba 分治乘法

算法引入

对于高精度数字的乘法,我们知道朴素的二重循环乘法

//进制:1e9
//下面所有的高精度数均为无符号数,用左开右闭区间表示,使用小端(little-endian)存储
//类似于 STL 迭代器
constexpr uint32_t BASE = 1000000000;

//二重循环乘法
//此函数将 [lbeg, lend) 和 [rbeg, rend) 表示的两个大整数相乘,并返回尾后指针
uint32_t* naive_multiply(
	const uint32_t* lbeg, const uint32_t* lend,
	const uint32_t* rbeg, const uint32_t* rend,
	uint32_t* output) {
	const size_t m1 = lend - lbeg, m2 = rend - rbeg;
	std::memset(output, 0, sizeof(uint32_t) * (m1 + m2 + 1));
	for (size_t i = 0; i < m1; ++i) {
		uint32_t carry = 0;
		for (size_t j = 0; j < m2; ++j) {
			uint64_t temp = 1ULL * lbeg[i] * rbeg[j] + output[i + j] + carry;
			carry = temp / BASE;
			output[i + j] = temp % BASE;
		}
		output[i + m2] += carry;
	}
	
	uint32_t* output_end = output + m1 + m2;
	if (output_end[-1] == 0)
		--output_end;
	return output_end;
}

如果两个数的数据大小(uint32_t 的数量)分别为 \(m_1\), \(m_2\),则算法的时间复杂度为 \(O(m_1m_2)=O(n^2)\)。但是有没有更快一点的算法呢?这就涉及到这篇文章所说的 Karatsuba 算法。

注意:

  • 下文中所有 \(m_1\)\(m_2\) 指两个数的数据规模,\(n=\max(m_1,m_2)\)

算法介绍

Karatsuba 算法利用了分治的原理。首先将两个数 \(U,V\) 各分成两段,\(U_0, U_1, V_0, V_1\),满足 $ U = U_1\beta ^ k + U_0, V = V_1\beta ^ k + V_0 $,其中 $ \beta $ 是使用的进制,\(k\) 一般是任选的,但是当 $ k = \max(m_1, m_2) / 2 $ 时最好。
我们可以尝试展开一些公式:

\[UV=(U_1\beta^k+U_0)(V_1\beta^k+V_0) \]

\[= U_1V_1\cdot \beta^{2k}+(U_1V_0+U_0V_1)\cdot \beta^{2k}+U_0V_0 \]

已知加减法的时间复杂度都是 \(O(n)\),如果我们递归计算 $ U_1V_1,U_0V_0,U_1V_0,U_0V_1 $,这个算法的时间复杂度是

\[T(n)=4T(n/2)+O(n) \implies T(n)=O(n^2) \]

相比于二重循环乘法,复杂度上没有任何改进,常数还增加了不少。所以我们需要想办法减少递归乘法的次数。
Karatsuba 想了这么一个方法:计算 $ (U_0+U_1)(V_0+V_1) $。

\[(U_0+U_1)(V_0+V_1)=U_0V_0+U_0V_1+V_0U_1+U_1V_1 \]

\[U_1V_0+U_0V_1=(U_0+U_1)(V_0+V_1)-U_0V_0-U_1V_1 \]

而 $ U_0V_0 $ 和 $ U_1V_1 $ 已经被求过了,4 次递归乘法可以变成 3 次。可是乘上 \(\beta^{2k}\) 的操作如何实现?我们观察到,十进制下一个数乘 10 就是把这个数整体向左移一位,二进制下一个数乘 2 就是把这个数在二进制表示下整体向左移一位。我们可以归纳出,一个以 \(\beta\) 为进制的高精度数乘上 \(\beta^k\),就是把它整体向左移动 \(k\) 个块。对于这一步可以使用 std::memmove \(O(n)\) 地处理,也可以通过调整指针等操作 \(O(1)\) 地处理。

所以我们可以得到 Karatsuba 算法:


Karatsuba 算法

输入:两个高精度数 $ U,V $(进制是 $ \beta $,规模分别是 \(m_1\)\(m_2\))。

输出:一个高精度数,满足这个数是输入两个数的乘积。

步骤

  1. 令 $ k = \max(m_1, m_2) / 2$
  2. 将两个数 \(U,V\) 分成 \(U_0,U_1,V_0,V_1\) ,条件见上。
  3. \(W_0=U_0V_0\)
  4. \(W_2=U_1V_1\)
  5. \(W_1=(U_0+U_1)(V_0+V_1)-W_0-W_2\)
  6. 计算 \(W=W_0+W_1\cdot \beta^k+W_2\cdot \beta^{2k}\)\(W\) 即为答案

计算 \(U_0+U_1\)\(V_0+V_1\) 需要辅助内存,我们可以先计算 \(W_1\),把 \(W_0\)\(W_2\) 当作辅助内存,这样可以减少不必要的内存分配。

因为加减法的时间复杂度是 \(O(n)\),我们可以算出 Karatsuba 算法的时间复杂度是:

\[T(n)=3T(n/2)+O(n) \implies T(n)=O(n^{\log_2{3}})=O(n^{1.585}) \]

根据数据规模选择算法

Karatsuba 的时间复杂度比普通的二重循环乘法要更优一些,但是并非所有情况我们都使用这个算法。在数据规模很小的情况下,二重循环乘法要比 Karatsuba 算法更快一些。所以我们可以设定一个界限,当数据规模大于这个界限时使用 Karatsuba 算法,否则使用二重循环乘法。这样对算法效率的提升是非常显著的,尤其是对于复杂度更优的其他算法,如 Toom-Cook, FFT 等(这篇文章暂不做讨论),这些算法的常数都非常高,数据量小的情况下使用它们不会获得很好的效果。界限值与具体实现有非常大的关系。例如,Java BigInteger 的实现中,如果数据规模大于 80,则使用 Karatsuba 算法;gmplib 的实现中,如果数据规模大于 20 到 30,则使用 Karatsuba 算法;gmplib 甚至对不同的处理器设置了不同的界限。根据数据规模选择算法这个思想在实际应用中是非常重要的,它可以很好地提升速度。

在这个实现中,我将界限值 KARATSUBA_THRESHOLD 设置为 100。

一个细节

  • 为什么令 \(k=\max(m_1, m_2)/2\) 而非 \(\min(m_1, m_2)/2\)

    答:如果采用后者,无法保证每次递归乘法的规模都小于等于 \(n/2\),尤其是当数据规模相差较大时。比如看这个例子:

    \[U=9,123,456,789 \]

    \[V=5,666 \]

    • \(k=\max(m_1,m_2)=5\)

      \[U_0=56,789 \]

      \[U_1=91,234 \]

      \[V_0=5,666 \]

      \[V_1=0 \]

    • \(k=\min(m_1,m_2)=2\)

      \[U_0=89 \]

      \[U_1=91,234,567 \]

      \[V_0=66 \]

      \[V_1=56 \]

    采用前者的划分方案时,每一次递归乘法的数据规模都小于等于 \(n/2\),而采用后者则不可以。

代码实现

洛谷上 P1919 就是高精度乘法的模板题。此题应用 FFT 解决,但是 Karatsuba + O2 optimize 勉强可以卡过,我的实现中每个点跑了 1.1s 左右。下面就是这道题的代码,对于 Karatsuba 算法更多的细节可以看代码中的注释。

#include <iostream>
#include <cstdio>
#include <cstdint>
#include <climits>
#include <algorithm>
#include <cstring>

constexpr uint32_t BASE = 1000000000;
constexpr size_t KARATSUBA_THRESHOLD = 100;

using std::min;
using std::max;

//这份代码直接使用 m, n 而不是 m1, m2 代表数据规模
//因为 m, n 比 m1, m2 在输入上更方便
 
//二重循环乘法
//此函数将 [lbeg, lend) 和 [rbeg, rend) 表示的两个大整数相乘,并返回尾后指针
uint32_t* naive_multiply(
	const uint32_t* lbeg, const uint32_t* lend,
	const uint32_t* rbeg, const uint32_t* rend,
	uint32_t* output) {
	const size_t m = lend - lbeg, n = rend - rbeg;

	//对于两个数中有一个或者两个数为零的情况特殊判断
	if (m == 0 || n == 0 || (m == 1 && *lbeg == 0) || (n == 1 && *rbeg == 0)) {
		output[0] = 0;
		return output + 1;
	}
	
	std::memset(output, 0, sizeof(uint32_t) * (m + n));
	for (size_t i = 0; i < m; ++i) {
		uint64_t carry = 0;
		for (size_t j = 0; j < n; ++j) {
			carry += 1ULL * lbeg[i] * rbeg[j] + output[i + j];
			output[i + j] = carry % BASE;
			carry /= BASE;
		}
		output[i + n] = carry;
	}
	
	uint32_t* output_end = output + m + n;
	if (output_end[-1] == 0)
		--output_end;
	return output_end;
}

struct ptrpair {
	uint32_t* beg;
	uint32_t* end;
};

//将 [lbeg, lend), [rbeg, rend) 表示的两个大整数相加
//返回尾后指针 
uint32_t* hprec_add(
	const uint32_t* lbeg, const uint32_t* lend,
	const uint32_t* rbeg, const uint32_t* rend,
	uint32_t* output) {
	if (lend - lbeg < rend - rbeg)
		std::swap(lbeg, rbeg), std::swap(lend, rend);
	
	uint32_t carry = 0;
	for (; rbeg != rend; ++lbeg, ++rbeg, ++output) {
		carry += *lbeg + *rbeg;
		*output = carry >= BASE ? carry - BASE : carry;
		carry = carry >= BASE;
	}
	
	for (; lbeg != lend; ++lbeg, ++output) {
		carry += *lbeg;
		*output = carry >= BASE ? carry - BASE : carry;
		carry = carry >= BASE;
	}
	
	if (carry)
		*output++ = carry;
	return output;
}

//将 [lbeg, lend), [rbeg, rend) 表示的两个数相减,要求左面的数大于等于右面的数 
//返回尾后指针 
uint32_t* hprec_sub(
	const uint32_t* lbeg, const uint32_t* lend,
	const uint32_t* rbeg, const uint32_t* rend,
	uint32_t* output) {
	const size_t m = lend - lbeg, n = rend - rbeg;
	const uint32_t* lmin = lbeg + min(m, n);
	uint32_t carry = 0;
	uint32_t* output_end = output;
	
	for (; lbeg != lmin; ++lbeg, ++rbeg, ++output_end) {
		carry = BASE + *lbeg - *rbeg - carry;
		*output_end = carry >= BASE ? carry - BASE : carry;
		carry = carry < BASE;
	}
	
	for (size_t i = min(m, n); i < max(m, n); ++i, ++lbeg, ++rbeg, ++output_end) {
		carry = BASE + (i < m ? *lbeg : 0) - (i < n ? *rbeg : 0) - carry;
		*output_end = carry >= BASE ? carry - BASE : carry;
		carry = carry < BASE;
	}
	
	
	
	//去掉前导零 
	//如果结果为零,下面的循环得到的结果可能与预想不同,
	//应先将第一个 block 设置为一个非零的数以避免此问题
	carry = output[0];
	output[0] = 0x3f3f3f3f; 
	for (; output_end[-1] == 0; --output_end);
	output[0] = carry;
	
	//上面的几条语句也可用
	//for (; output_end[-1] == 0 && output_end != output + 1; --output_end)
	//来代替,但是上面的语句可以减少几次判断(output_end != output + 1)的次数 
	
	return output_end;
}

//使用 Karatsuba 算法将 [lbeg, lend), [rbeg, rend) 表示的两个大整数相乘 
//返回尾后指针 
uint32_t* karatsuba_multiply(
	const uint32_t* lbeg, const uint32_t* lend,
	const uint32_t* rbeg, const uint32_t* rend,
	uint32_t* output) {
		
	size_t m = lend - lbeg, n = rend - rbeg;

	//数据规模较小时,直接使用二重循环乘法
	if (m < KARATSUBA_THRESHOLD || n < KARATSUBA_THRESHOLD)
		return naive_multiply(lbeg, lend, rbeg, rend, output);

	//对于两个数中有一个或者两个数为零的情况特殊判断
	if (m == 0 || n == 0 || (m == 1 && *lbeg == 0) || (n == 1 && *rbeg == 0)) {
		output[0] = 0;
		return output + 1;
	}
	
	//不妨假设 m > n 
	if (m < n)
		std::swap(lbeg, rbeg), std::swap(lend, rend), std::swap(m, n);
	
	size_t mid = m >> 1;
	const uint32_t* lmid = lbeg + mid;
	const uint32_t* rmid = rbeg + mid;
	
	//如果小的数特别小 (n < m/2),导致 rmid > rend ,就要进行调整 
	rmid = rmid > rend ? rend : rmid;
	
	//W0, W1, W2 的空间 
	//其中 W0 可以直接使用 output 而无需单独分配内存 
	ptrpair W0{output}, W1{new uint32_t[m * 2 + 10]}, W2{new uint32_t[m * 2 + 10]};
	std::memset(W1.beg, 0, sizeof(uint32_t) * mid);
	std::memset(W2.beg, 0, sizeof(uint32_t) * mid * 2);
	W1.beg += mid;
	W2.beg += mid * 2;
	
	//W0 = U0 + U1
	W0.end = hprec_add(lbeg, lmid, lmid, lend, W0.beg);

	//W2 = V0 + V1
	W2.end = hprec_add(rbeg, rmid, rmid, rend, W2.beg);

	//递归:W1 = W0 * W2 = (U0 + U1)(V0 + V1)
	W1.end = karatsuba_multiply(W0.beg, W0.end, W2.beg, W2.end, W1.beg);

	//递归:W0 = U0 * V0
	W0.end = karatsuba_multiply(lbeg, lmid, rbeg, rmid, W0.beg);

	//递归:W2 = U1 * V1
	W2.end = karatsuba_multiply(lmid, lend, rmid, rend, W2.beg);

	//W1 -= (W0 + W2)
	W1.end = hprec_sub(W1.beg, W1.end, W0.beg, W0.end, W1.beg);
	W1.end = hprec_sub(W1.beg, W1.end, W2.beg, W2.end, W1.beg);
	
	//W1 *= beta^k
	W1.beg -= mid;

	//W2 *= beta^2k
	W2.beg -= mid * 2;
	uint32_t* output_end = nullptr;

	//W = W0 + W1 + W2,W 即为答案
	output_end = hprec_add(W0.beg, W0.end, W1.beg, W1.end, output);
	output_end = hprec_add(output, output_end, W2.beg, W2.end, output);
	
	delete[] W1.beg;
	delete[] W2.beg; //删除辅助内存 
	return output_end;
}

//字符串转高精度,返回尾后指针 
uint32_t* serialize(char* beg, char* end, uint32_t* output) {
	uint32_t block = 0;
	uint32_t pow10[10] = {};
	pow10[0] = 1;
	for (int i = 1; i <= 9; ++i)pow10[i] = pow10[i-1] * 10;
	
	for (size_t i = 0; i < end - beg; ++i) {
		block += pow10[i % 9] * (beg[end - beg - i - 1] - '0');
		
		//该 block 输入完毕 
		if ((i+1) % 9 == 0) {
			*output++ = block;
			block = 0;
		}
	}
	if (block)
		*output++ = block;
	return output;
}

char numstr[2000004];
uint32_t numa[111120], numb[111120], numc[222240];
int main() {
	char* num1 = numstr, *num2 = numstr + 1000002;
	scanf("%s", num1);
	scanf("%s", num2);
	ptrpair a{numa}, b{numb};
	
	//计算 
	uint32_t* result = karatsuba_multiply(
		numa, serialize(num1, num1 + strlen(num1), numa),
		numb, serialize(num2, num2 + strlen(num2), numb),
		numc
	);
	
	//高精度转字符串(逆序)
	char* chptr = num1;
	for (uint32_t* p = numc; p != result - 1; ++p) {
		for (int i = 0; i < 9; ++i) {
			*chptr++ = *p % 10 + '0';
			*p /= 10;
		}
	}
	while (result[-1]) {
		*chptr++ = result[-1] % 10 + '0';
		result[-1] /= 10;
	}
	
	while (chptr[-1] == '0' && chptr != num1)--chptr;
	std::reverse(num1, chptr); //反转 
	*chptr++ = '\0';
	
	//输出结果 
	puts(num1);
	return 0;
}
posted @ 2021-11-06 16:01  Messier51  阅读(1206)  评论(0)    收藏  举报