AT AGC013D Piling Up

首先因为每一步中的步骤 1 和步骤 2,3 看着并没有什么关联,所以可以把每一步拆成两步:步骤 1,步骤 2,3。

因为是对一个序列计数,且这个序列并没有什么明显的性质,于是尝试找一个较为简单的判定条件

一个想法是复原,因为一个终止状态要能够合法则必须要存在起始状态使其能被操作得到。
当然也可以逆向,考虑由初始状态构造至该终止状态。

同时发现因为每一步被取走的块数量是固定的,初始的块数也是固定的。
这说明每一步的总块数都是固定的,记为 \(s_i\)这也说明对于每个时刻只需要关心红块的数量
于是可以考虑对于每一步都维护一个集合 \(c_i\) 代表 \(1\sim i\) 步均合法,在第 \(i\) 步可能的红块的数量的集合。
那么就可以考虑进入第 \(i + 1\) 步并递推 \(c_{i + 1}\) 了:

  • \(i + 1\) 步对应步骤 1:
    考虑 \(x\in c_i\),如果这里选的是红色,那么只要 \(x\ge 1\),就可以把 \(x - 1\) 加入 \(c_{i + 1}\);若为蓝色,则 \(s_i - x\ge 1\) 就可以把 \(x\) 加入 \(c_{i + 1}\)
  • \(i + 1\) 步对应步骤 2,3:
    此时因为前面放进来了一个红一个蓝,所以一定能选出来。
    考虑 \(x\in c_i\),若选的红色,那么把 \(x\) 放进 \(c_{i + 1}\);若选的蓝色,那么把 \(x + 1\) 放进 \(c_{i + 1}\)

对这个 \(c_i\) 的递推进行一些基本的观察,发现每一次都是形如整体偏移加上一些边界数字(极大值极小值)的去除,并且并不会加入重复数字进入集合
再结合上初始值 \(c_i = \{0, 1, \cdots, n - 1, n\}\)发现能够用 \({[l_i, r_i]}\) 表示出 \({c_i = \{l_i, l_i + 1, \cdots, r_i - 1, r_i\}}\),而且这样递推关系也就非常好处理了。

于是一个想法就是直接表示成 \(f_{i, l, r}\) 表示考虑了 \(1\sim i\),此时 \([l_i, r_i] = [l, r]\) 对应的 \(1\sim i\) 的序列的数量,转移仿照上文。
但是这样只能做到 \(\mathcal{O}(mn^2)\),还需要进一步优化。

进一步思考,那么到了终止状态,若合法的也一定会是一个区间形式。
而每个方案贡献都为 \({1}\),所以考虑点边容斥
即令 \(f_{i, j}\) 表示考虑了 \(1\sim i\),此时 \(j\in [l_i, r_i]\)\(j\) 是合法的),对应的 \(1\sim i\) 的序列的数量;令 \(g_{i, j}\) 表示考虑了 \(1\sim i\),此时 \(j, j + 1\in [l_i, r_i]\)\([j, j + 1]\) 这条线段是合法的),对应的 \(1\sim i\) 的序列数量。
转移式依旧可以参照上文的分类讨论,或者可以见实现。

那么点边容斥后,就可以表示出 \(\operatorname{ans} = \sum\limits_{i = 0}^n f_{m, i} - \sum\limits_{i = 0}^{n - 1}g_{m, i}\) 了。

时间复杂度 \(\mathcal{O}(nm)\)

#include<bits/stdc++.h>
constexpr int mod = 1e9 + 7;
inline void add(int &x, const int y) { (x += y) >= mod && (x -= mod); }
constexpr int maxm = 6e3 + 10, maxs = 3e3 + 10;
int f[maxm][maxs], g[maxm][maxs];
int main() {
    int n, m;
    scanf("%d%d", &n, &m), m *= 2;
    for (int i = 0; i <= n; i++) f[0][i] = 1;
    for (int i = 0; i < n; i++) g[0][i] = 1;
    int s = n;
    for (int t = 1; t <= m; t++) {
        if (t % 2 == 1) {
            for (int x = 0; x <= s; x++) {
                int y = s - x;
                if (x) add(f[t][x - 1], f[t - 1][x]);
                if (y) add(f[t][x], f[t - 1][x]);
            }
            for (int x = 0; x < s; x++) {
                int y = s - x - 1;
                if (x) add(g[t][x - 1], g[t - 1][x]);
                if (y) add(g[t][x], g[t - 1][x]);
            }
        } else {
            for (int x = 0; x <= s; x++) {
                int y = s - x;
                add(f[t][x], f[t - 1][x]);
                add(f[t][x + 1], f[t - 1][x]);
            }
            for (int x = 0; x < s; x++) {
                int y = s - x - 1;
                add(g[t][x], g[t - 1][x]);
                add(g[t][x + 1], g[t - 1][x]);
            }
            s += 2;
        }
        s--;
    }
    int ans = 0;
    for (int i = 0; i <= s; i++) add(ans, f[m][i]);
    for (int i = 0; i < s; i++) add(ans, mod - g[m][i]);
    printf("%d\n", ans);
    return 0;
}
posted @ 2025-04-09 18:42  rizynvu  阅读(41)  评论(1)    收藏  举报