洛谷 CF1097H. Mateusz and an Infinite Sequence

首先观察到生成序列的方式非常像直接复制若干遍,只是加上了一个权值而已。
如何用尽量少的东西去刻画所有的序列,还要方便查询?

我们只考虑复制的其中一段,而不是整个序列。
它与第几轮复制有关,其次还和整段加上的数有关。
因此设 $dp_{i,j}$ 表示当前是第 $i$ 轮,整段加上了 $j$ 的序列信息。

既然是 dp,就要支持合并状态信息。
考虑记录什么信息才能简单且有效地合并。
那当然是对每个状态记录:当前整个序列长度、当前序列答案、前缀和后缀 $\geq b_i$ 的信息(bitset)。
在合并更新答案的时候只要把 bitset 对应的位与运算即可。

对于查询,只需要从大到小枚举 $d^i$,并不断合并 $f_{i,\dots}$ 即可,第二维根据 $gen$ 数组求出。

#include <bits/stdc++.h>
using namespace std;
const int N = 3e4 + 15;
int d, m, n, gen[25], b[N];
long long qmi[66];
bitset<N> qwq;
long long L, R;

bitset<N> Pre(int x) { return (qwq << n - 1 - x) & qwq; }   // 前缀 x 个是 1 的 bit
bitset<N> Suf(int x) { return (qwq >> n - 1 - x) & qwq; }   // 后缀 x 个是 1 的 bit

struct node {
    bitset<N> pre, suf; //存储 n-1 位信息
    long long len, cnt;
    void clr() { pre.reset(), suf.reset(), len = 0, cnt = 0ll; }
} dp[66][66];
node merge(node a, node b) {
    if (!a.len) return b;   //qwq
    node res;
    res.pre = a.pre, res.suf = b.suf, res.len = a.len + b.len, res.cnt = a.cnt + b.cnt;
    //去除首尾多余部分
    if (a.len < n - 1) res.pre &= (b.pre >> a.len) | Pre(a.len);
    if (b.len < n - 1) res.suf &= (a.suf << b.len) | Suf(b.len);
    if (res.len > n - 1) {
        bitset<N> bit = a.suf & b.pre;
        // 去除多余的位
        if (a.len < n - 1) bit &= Suf(a.len);
        if (b.len < n - 1) bit &= Pre(b.len);
        res.cnt += bit.count();
    }
    return res;
}

void init(long long lim) {
    for (int i = 1; i <  n; i++) qwq.set(i);
    qmi[0] = 1; for (int i = 1; i <= 64; i++) qmi[i] = qmi[i - 1] * 1ll * d;
    int dep = 0; while (qmi[dep + 1] < lim) dep++;
    for (int i = 0; i < m; i++) {
        dp[0][i].len = 1;
        if (n == 1) dp[0][i].cnt = (i <= b[1]);
        else for (int j = 1; j <= n; j++) dp[0][i].pre[j - 1] = (i <= b[j]), dp[0][i].suf[j] = (i <= b[j]);
    }
    for (int i = 1; i <= dep; i++)
        for (int j = 0; j < m; j++)
            for (int k = 1; k <= d; k++) dp[i][j] = merge(dp[i][j], dp[i - 1][(j + gen[k]) % m]);
}
long long solve(long long p) {
    int dep = 0; while (qmi[dep + 1] < p) dep++;
    node ans; ans.clr();
    int sum = 0;
    for (int i = dep; i >= 0; i--)
        for (int j = 1; j <= d; j++)
            if (p >= qmi[i]) p -= qmi[i], ans = merge(ans, dp[i][(sum + gen[j]) % m]);
            else { (sum += gen[j]) %= m; break; }
    return ans.cnt;
}

int main() {
    scanf("%d%d", &d, &m);
    for (int i = 1; i <= d; i++) scanf("%d", &gen[i]);
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
    scanf("%lld%lld", &L, &R);
    init(R);
    printf("%lld\n", solve(R) - solve(L + n - 2) );
    return 0;
}
posted @ 2025-03-29 10:43  Conan15  阅读(4)  评论(0)    收藏  举报  来源