乘法快速幂

例题:P1226 【模板】快速幂

给定三个整数 \(a,b,p\),求 \(a^b \bmod p\)\(0 \le a,b \le 2^{31}, \ a+b \gt 0, \ 2 \le p \lt 2^{31}\)

最朴素的想法是直接用一个循环,将 \(a\) 连乘 \(b\) 次,每次乘法后都对 \(p\) 取模。但题目给出的数据范围中,\(b\) 的值可以非常大,如果直接循环 \(b\) 次,计算量过大,会导致超时。

因此需要一种更高效的算法来处理大指数的幂运算,这就是快速幂算法。

快速幂的核心思想是二进制拆分,它将指数 \(b\) 拆解为二进制形式,从而显著减少乘法次数。例如,计算 \(a^{10}\),可以将指数 \(10\) 转换为二进制的 \(1010\)\(10 = 8 + 2 = 1 \cdot 2^3 + 0 \cdot 2^2 + 1 \cdot 2^1 + 0 \cdot 2^0\),所以,\(a^{10} = a^{8+2} = a^8 \cdot a^2\)

这样,就不需要计算 \(a^1, a^2, a^3, \dots, a^{10}\),而只需要计算 \(a^2, a^4, a^8\) 这些 \(a\)\(2^k\) 次幂。\(a^2\) 可以由 \(a^1 \cdot a^1\) 得到,\(a^4\) 可以由 \(a^2 \cdot a^2\) 得到,\(a^8\) 可以由 \(a^4 \cdot a^4\) 得到,这种方式可以很快地得到所有需要的 \(a^{2^k}\) 项。

最后,根据 \(b\) 的二进制表示中哪些位是 \(1\),将对应的 \(a^{2^k}\) 项乘起来即可。整个过程中的所有乘法都进行取模运算,以防止中间结果溢出。

参考代码
#include <cstdio>

/**
 * @brief 快速幂函数,计算 (x^y) % mod 的值
 * 
 * @param x 底数
 * @param y 指数
 * @param mod 模数
 * @return int 计算结果
 */
int quickpow(int x, int y, int mod) {
    // 初始化结果为 1。任何数的0次幂都是1,这是累乘的初始值。
    int res = 1;
    // 当指数 y 大于 0 时,循环处理
    while (y > 0) {
        // 判断 y 的二进制最低位是否为 1。
        // 如果是 1,说明当前位的权重需要乘入结果。
        if (y % 2 == 1) {
            // 将当前底数 x 乘到结果 res 中。
            // 使用 1ll 将 res 转换为 long long 类型,防止 res * x 的中间结果溢出 int。
            res = 1ll * res * x % mod;
        }
        // 底数自身平方,用于下一轮计算。
        // x 的值依次变为 a^1, a^2, a^4, a^8, ...
        // 同样使用 1ll 防止 x * x 的中间结果溢出。
        x = 1ll * x * x % mod;
        // 指数 y 右移一位(相当于除以 2),处理下一位。
        y /= 2;
    }
    // 返回最终计算结果
    return res;
}

int main()
{
    int a, b, p;
    // 读取输入的 a, b, p
    scanf("%d%d%d", &a, &b, &p);
    // 调用快速幂函数并按格式输出结果
    printf("%d^%d mod %d=%d\n", a, b, p, quickpow(a, b, p));
    return 0;
}

快速幂也可以通过递归函数来实现,这个思路建立在两个简单的数学式上:

  1. 如果指数 \(b\) 是偶数,那么 \(a^b = a^{b/2} \cdot a^{b/2} = (a^{b/2})^2\)
  2. 如果指数 \(b\) 是奇数,那么 \(a^{b-1} \cdot a = (a^{(b-1)/2})^2 \cdot a\)。由于 \(b/2\) 在整数除法中等于 \((b-1)/2\),所以也可以写成 \(a^b = (a^{b/2})^2 \cdot a\)

递归函数正是基于这个思想:

  • 递归的“递”过程:不断地将指数除以 \(2\),直到指数为 \(0\),这是递归的终止条件。
  • 递归的“归”过程:在从最深层返回时,根据当前指数的奇偶性,计算并返回当前层的结果。

这是一种分治的策略,将计算 \(a^b\) 的问题,在一次递归后缩小为计算 \(a^{b/2}\) 的问题,问题规模减半,因此算法的时间复杂度为 \(O(\log b)\),效率非常高。

参考代码
#include <cstdio>

/**
 * @brief 快速幂的递归实现,计算 (x^y) % mod 的值
 * 
 * @param x 底数
 * @param y 指数
 * @param mod 模数
 * @return int 计算结果
 */
int quickpow(int x, int y, int mod) {
    // 递归终止条件:任何数的 0 次幂都是 1
    if (y == 0) {
        return 1;
    }

    // 递归调用,计算 x^(y/2) % mod 的值
    // 通过将指数减半,不断缩小问题规模
    int tmp = quickpow(x, y / 2, mod);

    // 计算 (x^(y/2))^2 % mod 的值
    // 这是 y 为偶数时的结果
    // 使用 1ll 将 tmp 转换为 long long,防止 tmp * tmp 的中间结果溢出 int
    int res = 1ll * tmp * tmp % mod;

    // 如果 y 是奇数,还需要额外乘以一个 x
    // 因为 x^y = (x^(y/2))^2 * x
    if (y % 2 == 1) {
        res = 1ll * res * x % mod;
    }

    // 返回当前层计算的结果
    return res;
}

int main()
{
    int a, b, p;
    // 读取输入的 a, b, p
    scanf("%d%d%d", &a, &b, &p);
    // 调用快速幂函数并按格式输出结果
    printf("%d^%d mod %d=%d\n", a, b, p, quickpow(a, b, p));
    return 0;
}

选择题:现在用如下代码来计算 \(x^n\),其时间复杂度为?

double quick_power(double x, unsigned n) {
    if (n == 0) return 1;
    if (n == 1) return x;
    return quick_power(x, n / 2) * quick_power(x, n / 2) * ((n & 1) ? x : 1);
}
  • A. \(O(n)\)
  • B. \(O(1)\)
  • C. \(O(\log n)\)
  • D. \(O(n \log n)\)
答案

这道题的正确答案是 A

这是一个典型的递归问题,但代码的写法存在效率陷阱。

分析一下函数调用的过程:quick_power(x, n) 会调用两次 quick_power(x, n / 2),而每个 quick_power(x, n / 2) 又会调用两次 quick_power(n / 4)

以此类推,形成的递归过程大致如下(以 n=8 为例):

image

在每一层,函数调用的数量都翻倍。递归的深度是 \(O(\log n)\),但在最底层 \(n=1\) 时,总共大约有 \(n/2\) 个节点,整个递归的节点总数大约是 \(1+2+4 + \cdots + 2^{\log_2 n} \approx n\)

由于每次函数调用都包含常数时间的乘法操作,所以总的时间复杂度与调用次数成正比,即 \(O(n)\)


这是一种效率较低的快速幂实现,标准高效的快速幂算法会先计算一次 quick_power(x, n / 2) 并将其结果存入一个临时变量,然后复用这个结果。那样实现时间复杂度才是 \(O(\log n)\),因为这样每次只产生一个递归调用,递归深度为 \(O(\log n)\)

例题:P10446 64位整数乘法

如果直接在程序中计算 a * b,其结果最大可达 \(10^{18} \times 10^{18} = 10^{36}\)。C++ 中标准的 64 为整型 long long(最大值约为 \(9 \times 10^{18}\))或 unsigned long long(最大值约 \(1.8 \times 10^{19}\))都无法存储这样大的中间结果,会导致数据溢出,从而得到错误的答案,因此需要避免直接计算 a * b

二进制拆分法(“龟速乘”)

这个方法的核心思想是将乘法运算 a * b 巧妙地转化为一系列加法运算,因为 a * b 等于 ba 相加。直接循环 b 次相加会超时,但可以利用 b 的二进制表示来加速这个过程,其思想与“快速幂”非常相似。

参考代码
#include <cstdio>
using ll = long long;

// 函数用于计算 (a * b) % p,防止中间结果溢出
// 使用二进制拆分法(类似快速幂)
ll mul(ll a, ll b, ll p) {
    ll res = 0;
    a %= p; // 先对a取模,缩小范围
    while (b > 0) {
        // 如果b的当前最低位是1,则把当前的a累加到结果中
        if (b & 1) {
            res = (res + a) % p;
        }
        // a翻倍,为计算b的下一位做准备
        a = (a + a) % p;
        // b右移一位,处理下一位
        b >>= 1;
    }
    return res;
}

int main() {
    ll a, b, p;
    scanf("%lld%lld%lld", &a, &b, &p);
    printf("%lld\n", mul(a, b, p));
    return 0;
}

在整个计算过程中,res + aa + a 的中间值最大不会超过 2 * p - 2,因为 ares 都是对 p 取模后的结果。由于 p 最大为 \(10^{18}\)2 * p 不会超过 long long 的表示范围,因此这种方法可以有效避免溢出。此算法的复杂度为 \(O(\log b)\),非常高效。

习题:P1045 [NOIP 2003 普及组] 麦森数

解题思路

\(2^P\) 是一个极其巨大的数字,直接计算并存储全部位数是不现实的。然而,题目只要求最后 500 位,这意味着只需要保留低 500 位进行高精度运算即可。

对于一个正整数 \(x\),其十进制位数 \(D\) 可以通过公式 \(D = \lfloor \log_{10} \rfloor + 1\) 计算。而本题要求 \(2^P-1\) 的位数,注意到 \(2^P\) 的末位一定是 \(2,4,6,8\) 之一,减一后不会发生退位导致位数减少。因此,\(2^P-1\) 的位数等于 \(2^P\) 的位数。

\(\lfloor \log_{10}(2^P) \rfloor + 1 = \lfloor P \times \log_{10}(2) \rfloor + 1\),使用 cmath 库中的 log10 函数即可直接计算。

使用高精度乘法配合快速幂计算 \(2^P\) 的最后 500 位,在进行高精度乘法时,任何超过 500 位的计算结果都可以直接丢弃,这大大减少了计算量。个位不可能是 0,所以 1 只需要简单地将个位减 1 即可,无需处理借位。

时间复杂度为 \(O(L^2 \log P)\),其中 \(L=500\) 是保留的位数。

参考代码
#include <cstdio>
#include <cmath>

// 题目要求输出最后 500 位,定义数组长度为 500
// 数组使用倒序存储,res[0] 是个位
const int LEN = 500;
int res[LEN], base[LEN], tmp[LEN]; // res: 结果, base: 基数, tmp: 乘法临时结果

// 高精度乘法函数,计算 t = a * b
// 由于只关注最后 500 位,超过部分直接丢弃
void mul(const int a[], const int b[], int t[]) {
    // 清空临时数组
    for (int i = 0; i < LEN; i++) t[i] = 0;
    
    // 逐位相乘,模拟竖式乘法
    for (int i = 0; i < LEN; i++) {
        // 超过 500 位的部分不需要计算
        for (int j = 0; j < LEN - i; j++) {
            t[i + j] += a[i] * b[j];
        }
    }
    
    // 处理进位
    for (int i = 0; i < LEN; i++) {
        if (t[i] >= 10) {
            // 如果不是最高有效位(第500位),则进位
            if (i + 1 < LEN) {
                t[i + 1] += t[i] / 10;
            }
            t[i] %= 10;
        }
    }
}

// 辅助函数:数组复制 d = s
void copy(int d[], int s[]) {
    for (int i = 0; i < LEN; i++) d[i] = s[i];
}

int main()
{
    int p; 
    scanf("%d", &p);
    
    // 1. 计算位数
    // 2^P - 1 的位数与 2^P 相同。(因为 2^P 不会是 100...0)
    // 位数公式:floor(P * log10(2)) + 1
    printf("%d\n", int(p * log10(2)) + 1);
    
    // 2. 高精度快速幂计算 2^P 的最后 500 位
    // 初始化结果为 1
    res[0] = 1; 
    // 初始化底数为 2
    base[0] = 2;
    
    // 快速幂模板
    while (p > 0) {
        if (p & 1) {
            mul(res, base, tmp); // res = res * base
            copy(res, tmp);
        }
        mul(base, base, tmp);    // base = base * base
        copy(base, tmp);
        p >>= 1;
    }
    
    // 3. 计算 2^P - 1
    // 因为 2^P (P>=1) 末位必定是偶数 (2,4,6,8),所以减 1 不需要借位,直接个位减 1
    res[0]--;
    
    // 4. 输出最后 500 位
    // 每行 50 个数字,共 10 行
    for (int i = LEN - 1; i >= 0; i--) {
        printf("%d", res[i]);
        // 每输出 50 个字符换行
        if (i % 50 == 0) printf("\n");
    }
    return 0;
}
posted @ 2025-09-06 00:03  RonChen  阅读(41)  评论(0)    收藏  举报