競プロ典型 90 問-难题

005倍增优化dp

题目大意(自己总结
只用数字 c1​,c2​,…,cK​ 可以构造出多少个 N 位正整数是 B 的倍数? 求除以 109+7 的余数。

  • $1 \leq K \leq 9$
  • $1 \leq c_1 \lt c_2 \lt \cdots \lt c_K \leq 9$
  • $1 \leq N \leq 10^{18}$
  • $2 \leq B \leq 1000$
    题目主要实现思路
    遇到这种凑倍数相关的题目,优先考虑dp[ i j ] 表示前i个数可以组成的数模上b的值,因此可得状态转移方程dp i (j * 10 + a[ k ])%b =dp i-1 j ,此时有三层循环N * B * K,因为N很大,所以考虑矩阵快速幂,可降到 O (B³×logN) ,引入一种新方法倍增法根据前面的递推公式dp[i + 1][(r × 10 + c[k]) % B] += dp[i][r],我们可以理解为从dp[0]开始,按dp[0]dp[1]dp[2]→…→dp[N]的顺序依次计算。但这种直接计算的方式需要 O (N) 步,效率过低。
    因此可以考虑先预处理出dp[ 1 ] dp 2 dp4 dp8 参考二进制的方式
    关键在于实现 “通过dp[i]数组和dp[j]数组,快速计算出dp[i+j]数组”。只要能实现这一点,我们就能像计算 3ⁿ那样,通过倍增法在 O (logN) 步内求出dp[N]数组。

这里的dp[i+j]表示 i+j 位整数对应的动态规划数组。我们可以将 i+j 位整数拆分为 “前 i 位” 和 “后 j 位” 两部分来分析:

  • 设前 i 位整数除以 B 的余数为 p(对应的数量为dp[i][p]种);
  • 设后 j 位整数除以 B 的余数为 q(对应的数量为dp[j][q]种)。

那么,由这两部分组成的 i+j 位整数除以 B 的余数为:(p × 10ʲ + q) % B

由此可推导出dp[i+j]的递推公式:dp[i + j][(p × tⱼ + q) % B] += dp[i][p] × dp[j][q]

其中,tⱼ表示 10ʲ除以 B 的余数。这正是 “通过dp[i]dp[j]计算dp[i+j]” 的递推公式,且该递推的时间复杂度为 O (B²)。

#include <bits/stdc++.h>

#define int long long

using namespace std;

const int N = 1e6 + 10;

const int mod = 1e9 + 7;

void solve()

{

    int n, b, k;

    cin >> n >> b >> k;

    vector<int> a(k, 0);

    for (int i = 0; i < k; i++)

    {

        cin >> a[i];

    }

    auto mul = [&](vector<int> &dpi, vector<int> &dpj, int powv) -> vector<int>

    {

        vector<int> res(b, 0);

        for (int p = 0; p < b; p++)

        {

            for (int q = 0; q < b; q++)

            {

                res[(p * powv + q) % b] += (dpi[p] * dpj[q]) % mod;

                res[(p * powv + q) % b] = res[(p * powv + q) % b] % mod;

            }

        }

        return res;

    };

    vector<int> tenpow(100, 0);

    tenpow[0] = 10;

    for (int i = 1; i < 100; i++)

    {

        tenpow[i] = (tenpow[i - 1] * tenpow[i - 1]) % b;

    }

    vector<vector<int>> fastdp(100, vector<int>(b, 0));

    for (int i = 0; i < k; i++)

    {

        fastdp[0][a[i] % b] += 1;

    }

    for (int i = 1; i < 100; i++)

    {

        fastdp[i] = mul(fastdp[i - 1], fastdp[i - 1], tenpow[i - 1]);

    }

    vector<int> res(b, 0);

    res[0] = 1;

    for (int i = 0; i < 63; i++)

    {

        if ((n >> i) & 1)

        {

            res = mul(res, fastdp[i], tenpow[i]);

        }

    }

    cout << res[0] << '\n';

}

signed main()

{

    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);

    int T;

    T = 1;

    // cin >> T;

    while (T--)

        solve();

    return 0;

}
posted @ 2025-11-05 23:12  Jwe1  阅读(5)  评论(0)    收藏  举报