洛谷 CF1097H. Mateusz and an Infinite Sequence
首先观察到生成序列的方式非常像直接复制若干遍,只是加上了一个权值而已。
如何用尽量少的东西去刻画所有的序列,还要方便查询?
我们只考虑复制的其中一段,而不是整个序列。
它与第几轮复制有关,其次还和整段加上的数有关。
因此设 $dp_{i,j}$ 表示当前是第 $i$ 轮,整段加上了 $j$ 的序列信息。
既然是 dp,就要支持合并状态信息。
考虑记录什么信息才能简单且有效地合并。
那当然是对每个状态记录:当前整个序列长度、当前序列答案、前缀和后缀 $\geq b_i$ 的信息(bitset)。
在合并更新答案的时候只要把 bitset 对应的位与运算即可。
对于查询,只需要从大到小枚举 $d^i$,并不断合并 $f_{i,\dots}$ 即可,第二维根据 $gen$ 数组求出。
#include <bits/stdc++.h>
using namespace std;
const int N = 3e4 + 15;
int d, m, n, gen[25], b[N];
long long qmi[66];
bitset<N> qwq;
long long L, R;
bitset<N> Pre(int x) { return (qwq << n - 1 - x) & qwq; } // 前缀 x 个是 1 的 bit
bitset<N> Suf(int x) { return (qwq >> n - 1 - x) & qwq; } // 后缀 x 个是 1 的 bit
struct node {
bitset<N> pre, suf; //存储 n-1 位信息
long long len, cnt;
void clr() { pre.reset(), suf.reset(), len = 0, cnt = 0ll; }
} dp[66][66];
node merge(node a, node b) {
if (!a.len) return b; //qwq
node res;
res.pre = a.pre, res.suf = b.suf, res.len = a.len + b.len, res.cnt = a.cnt + b.cnt;
//去除首尾多余部分
if (a.len < n - 1) res.pre &= (b.pre >> a.len) | Pre(a.len);
if (b.len < n - 1) res.suf &= (a.suf << b.len) | Suf(b.len);
if (res.len > n - 1) {
bitset<N> bit = a.suf & b.pre;
// 去除多余的位
if (a.len < n - 1) bit &= Suf(a.len);
if (b.len < n - 1) bit &= Pre(b.len);
res.cnt += bit.count();
}
return res;
}
void init(long long lim) {
for (int i = 1; i < n; i++) qwq.set(i);
qmi[0] = 1; for (int i = 1; i <= 64; i++) qmi[i] = qmi[i - 1] * 1ll * d;
int dep = 0; while (qmi[dep + 1] < lim) dep++;
for (int i = 0; i < m; i++) {
dp[0][i].len = 1;
if (n == 1) dp[0][i].cnt = (i <= b[1]);
else for (int j = 1; j <= n; j++) dp[0][i].pre[j - 1] = (i <= b[j]), dp[0][i].suf[j] = (i <= b[j]);
}
for (int i = 1; i <= dep; i++)
for (int j = 0; j < m; j++)
for (int k = 1; k <= d; k++) dp[i][j] = merge(dp[i][j], dp[i - 1][(j + gen[k]) % m]);
}
long long solve(long long p) {
int dep = 0; while (qmi[dep + 1] < p) dep++;
node ans; ans.clr();
int sum = 0;
for (int i = dep; i >= 0; i--)
for (int j = 1; j <= d; j++)
if (p >= qmi[i]) p -= qmi[i], ans = merge(ans, dp[i][(sum + gen[j]) % m]);
else { (sum += gen[j]) %= m; break; }
return ans.cnt;
}
int main() {
scanf("%d%d", &d, &m);
for (int i = 1; i <= d; i++) scanf("%d", &gen[i]);
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
scanf("%lld%lld", &L, &R);
init(R);
printf("%lld\n", solve(R) - solve(L + n - 2) );
return 0;
}

浙公网安备 33010602011771号