DP 杂记

DP 杂记

柱状图 DP

柱状图 DP 的显著的特点是选的数不会太多,可以很好维护选出的数满足单调性、互异性的性质,转移分为增加一个柱子和把所有柱子整体升高。

LOJ6406. 「ICPC World Finals 2018」绿宝石之岛

\(n\) 堆宝石,每堆 \(1\) 个。每轮会等概率选择 \(1\) 个宝石将其变为 \(2\) 个,求 \(d\) 轮后前 \(r\) 多的宝石堆的宝石数量和的期望。

\(n, d \le 500\)\(1 \le r \le n\)

考虑末状态,设第 \(i\) 堆多了 \(a_i\) 个宝石,对应方案数为 \(\binom{d}{a_1, a_2, \cdots, a_n} \times \prod a_i! = d!\) (后一项是每次选择哪个宝石分裂),乘上选出一组 \(\{ a_i \}\) 的方案数得到 \(\binom{d + n - 1}{d} \times d! = \frac{(n + d - 1)!}{(n - 1)!}\) 。而分裂的总方案数为 \(\frac{(n + d - 1)!}{(n - 1)!}\) ,于是每种末状态是等概率出现的,可以直接对方案计数。

\(g_{i, j}\) 表示还剩 \(i\) 堆的 \(a\) 正在增加、用了 \(j\) 个宝石的方案数,转移就是保留 \(k\) 堆增加 \(1\) ,剩下的扔掉:

\[g_{i, j} = \sum_{k = 0}^{\min(i, j)} \binom{i}{k} g_{k, j - k} \]

再设 \(f_{i, j}\) 表示,还剩 \(i\) 堆的 \(a\) 正在增加、用了 \(j\) 个宝石时所有方案前 \(r\) 大的宝石数量和的总和,类似的有:

\[f_{i, j} = \sum_{k = 0}^{\min(i, j)} \binom{i}{k} (f_{k, j - k} + g_{k, j - k} \times \min(k, r)) \]

答案即为 \(\frac{f_{n, d}}{g_{n, d}} + r\) ,因为要算上初始的一个宝石。

时间复杂度 \(O(nd \min(n, d))\)

#include <bits/stdc++.h>
using namespace std;
const int N = 5e2 + 7;

double C[N][N], f[N][N], g[N][N];

int n, d, r;

signed main() {
    scanf("%d%d%d", &n, &d, &r);

    for (int i = C[0][0] = 1; i <= n; ++i)
        for (int j = C[i][0] = 1; j <= i; ++j)
            C[i][j] = C[i - 1][j] + C[i - 1][j - 1];

    g[0][0] = 1;

    for (int i = 1; i <= n; ++i)
        for (int j = 0; j <= d; ++j)
            for (int k = 0; k <= min(i, j); ++k) {
                g[i][j] += C[i][k] * g[k][j - k];
                f[i][j] += C[i][k] * (f[k][j - k] + g[k][j - k] * min(k, r));
            }

    printf("%.9lf", f[n][d] / g[n][d] + r);
    return 0;
}

LOJ6077. 「2017 山东一轮集训 Day7」逆序对

给定 \(n, k\) ,求长度为 \(n\) 、逆序对数为 \(k\) 的排列数量模 \(10^9 + 7\)

\(n, k \le 10^5\)

考虑在 \(1 \sim i - 1\) 的排列中插入 \(i\) ,则会对逆序对产生 \([0, i - 1]\) 不等的贡献。问题转化为求 \(\sum_{i = 1}^n x_i = k\) 的解的方案数,其中 \(x_i \in [0, i - 1]\)

考虑容斥,对于钦定不满足的位置 \(i\) ,令 \(x_i \gets x_i - i\) ,这样限制均转化为 \(x_i \ge 0\) 。得到:

\[ans = \sum_{S \subseteq \{ 1, 2, \cdots, n \}} (-1)^{|S|} \times c(k - \sum_{i \in S} i) \]

其中 \(c(k)\) 表示 \(\sum_{i = 1}^n x_i = k\) 的方案数,不难用插板法算出。

\(f_{i, j}\) 表示从 \(1 \sim n\) 中选 \(i\) 个不同的数,和为 \(j\) 的方案数。转移时分类讨论:

  • 整体垫高一层:\(f_{i, j} \to f_{i, j + i}\)
  • 整体垫高一层并在末尾加上一列:\(f_{i, j} \to f_{i + 1, j + i + 1}\)
  • 防止选的数 \(>n\)\(- f_{i - 1, j - (n + 1)} \to f_{i, j}\)

时间复杂度 \(O(k \sqrt{k})\)

#include <bits/stdc++.h>
using namespace std;
const int Mod = 1e9 + 7;
const int N = 2e5 + 7, B = 5e2 + 7;

int f[B][N], fac[N], inv[N], invfac[N];

int n, k;

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

inline int sgn(int n) {
    return n & 1 ? Mod - 1 : 1;
}

inline void prework(int n) {
    fac[0] = fac[1] = 1;
    inv[0] = inv[1] = 1;
    invfac[0] = invfac[1] = 1;
    
    for (int i = 2; i <= n; ++i) {
        fac[i] = 1ll * fac[i - 1] * i % Mod;
        inv[i] = 1ll * (Mod - Mod / i) * inv[Mod % i] % Mod;
        invfac[i] = 1ll * invfac[i - 1] * inv[i] % Mod;
    }
}

inline int C(int n, int m) {
    return m > n || m < 0 ? 0 : 1ll * fac[n] * invfac[m] % Mod * invfac[n - m] % Mod;
}

inline int calc(int m) {
    return C(m + n - 1, n - 1);
}

signed main() {
    scanf("%d%d", &n, &k);
    prework(n + k), f[0][0] = f[1][1] = 1;

    for (int i = 1; i <= n && i * (i + 1) / 2 <= k; ++i)
        for (int j = i * (i + 1) / 2; j <= k; ++j) {
            if (j > n)
                f[i][j] = dec(f[i][j], f[i - 1][j - (n + 1)]);

            if (j + i <= k)
                f[i][j + i] = add(f[i][j + i], f[i][j]);

            if (j + i + 1 <= k)
                f[i + 1][j + i + 1] = add(f[i + 1][j + i + 1], f[i][j]);
        }

    int ans = 0;

    for (int i = 0; i <= n && i * (i + 1) / 2 <= k; ++i)
        for (int j = i * (i + 1) / 2; j <= k; ++j)
            ans = add(ans, 1ll * sgn(i) * f[i][j] % Mod * calc(k - j) % Mod);

    printf("%d", ans);
    return 0;
}

P8340 [AHOI2022] 山河重整

求有多少个 \(S \subseteq \{ 1, 2, \cdots, n \}\) 满足对于 \(k = 1, 2, \cdots, n\) ,均存在 \(T \subseteq S\) 满足 \(\sum_{x \in T} x = k\)

\(n \le 5 \times 10^5\)

显然直接做 01 背包判定 \(S\) 的合法性并不好推广,考虑直接找判定的充要条件:对于任意 \(i \in [1, n]\) ,选取的 \(\le i\) 的数的和 \(\ge i\) ,这个条件的充要性可以归纳证明。

\(f_{i, j}\) 表示考虑了前 \(i\) 个数,已经覆盖了 \([1, j]\) 的前缀的方案数。转移时考虑第 \(i\) 个数是否选取,若选取则 \(j \to j + i\) ,合法状态需要满足 \(j \ge i\) ,时间复杂度 \(O(n^2)\)

考虑容斥,找到第一个不满足条件的位置 \(i\) 开始转移,并乘上 \(-1\) 的系数。此时 \(\le i - 1\) 的数的和为 \(i - 1\) ,然后 \(i\) 不能选,后面的数随意。设 \(f_i\) 表示第一次使得选出的 \([1, i]\) 的数和为 \(i\) 的方案数,答案即为:

\[2^n - \sum_{i = 0}^{n - 1} f_i \times 2^{n - i - 1} \]

\(f_i\) 同样考虑容斥。若不考虑第一次不满足条件的限制,则直接做柱状图 DP 即可,具体转移就是倒序枚举剩下 \(i\) 个未确定的柱子:

  • \(i\) 个柱子的高度要和其它柱子区分开来,整体增高 \(1\) 即可,具体就是降序枚举 \(j\) 并令 \(f_j = f_{j - i}\)
  • 然后需要考虑一共只有 \(i\) 个高度为 \(1\) 的柱子的情况,即令 \(f_i = 1\)
  • 最后考虑这 \(i\) 个柱子垫高若干层,升序枚举 \(j\) 然后做转移 \(f_j = f_j + f_{j - i}\)

接下来考虑去重,若当前需要用 \(f_j\) 去重 \(f_i\) ,系数就是 \([j + 2, i]\) 内选的数的和为 \(i - j\) 的方案数。

注意到系数非 \(0\)\(j\) 一定满足 \(j \le \frac{i}{2}\) ,于是考虑倍增,每次处理左区间对右区间的影响,每层可以做一次柱状图 DP 转移。

具体就是考虑先将每个 \(f_j\) 视为已经对和贡献了 \(j\) ,那么只要满足和为 \(i\) 即可。第一步和第三步是一样的,而第二步则是一共只有 \(i\) 个高度均为 \(j + 2\) 的柱子的情况,注意这里的转移系数为 \(f\)

时间复杂度 \(O(n \sqrt{n})\)

#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 7;

int f[N], g[N];

int n, Mod;

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

void solve(int n) {
    if (n <= 1)
        return;

    solve(n / 2), memset(g, 0, sizeof(int) * (n + 1));

    for (int i = sqrt(n * 2) + 1; i; --i) {
        for (int j = n; j >= i; --j)
            g[j] = g[j - i];

        for (int j = 0; j + i * (j + 2) <= n; ++j)
            g[j + i * (j + 2)] = add(g[j + i * (j + 2)], f[j]);

        for (int j = i; j <= n; ++j)
            g[j] = add(g[j], g[j - i]);
    }

    for (int i = n / 2 + 1; i <= n; ++i)
        f[i] = dec(f[i], g[i]);
}

signed main() {
    scanf("%d%d", &n, &Mod);

    for (int i = sqrt(n * 2) + 1; i; --i) {
        for (int j = n; j >= i; --j)
            f[j] = f[j - i];

        f[i] = 1;

        for (int j = i; j <= n; ++j)
            f[j] = add(f[j], f[j - i]);
    }

    f[0] = 1, solve(n);
    int res = 0, pw = 1;

    for (int i = 0; i < n; ++i)
        res = add(res, 1ll * f[n - i - 1] * pw % Mod), pw = 2ll * pw % Mod;

    printf("%d", dec(pw, res));
    return 0;
}

[ARC107D] Number of Multisets

给定 \(n, k\) ,求大小为 \(n\) 、和为 \(k\) 、每个元素均可以表示为 \(2^{-i}\) 的可重集数量。

\(k \le n \le 3000\)

考虑将选数的过程转化为:初始有 \(n\)\(1\) ,每次选择一个集合将其减半。

为了避免算重,考虑柱状图 DP,每次可以加入一个 \(1\) 表示新增一个柱子,和减半表示指数减少 \(1\) (垫高一层)。

\(f_{i, j}\) 表示用了 \(i\) 个数、和为 \(j\) 的方案数,则:

  • 新增一个柱子:\(f_{i, j} \gets f_{i - 1, j - 1}\)
  • 垫高一层:\(f_{i, j} \gets f_{i, 2j}\) ,转移顺序为完全背包。
#include <bits/stdc++.h>
using namespace std;
const int Mod = 998244353;
const int N = 3e3 + 7;

int f[N][N];

int n, k;

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

signed main() {
    scanf("%d%d", &n, &k);
    f[0][0] = 1;

    for (int i = 1; i <= n; ++i) 
        for (int j = i; j; --j)
            f[i][j] = add(f[i - 1][j - 1], j * 2 <= i ? f[i][j * 2] : 0);

    printf("%d", f[n][k]);
    return 0;
}

CF1225G To Make 1

给定 \(a_{1 \sim n}\) 和常数 \(k\) ,每次可以选出 \(x, y\) 两个数合并为 \(f(x + y)\) ,其中 \(f(a)\) 表示 \(a\) 不断除以 \(k\) 直至无法整除时所得的数。

构造一组方案使得最后剩下 \(1\) ,或报告无解。

\(n \le 16\)\(\sum a_i \le 2000\)\(k \nmid a_i\)

考虑每个 \(a_i\) 被除以 \(k\) 的次数 \(b_i\) ,则充要条件就是存在 \(b_{1 \sim n}\) 满足 \(\sum a_i \times k^{-b_i} = 1\)

证明:必要性显然,考虑证明充分性。令 \(B = \max b_i\) ,则 \(k^B = \sum a_i \times k^{B - b_i}\) 。对于 \(b_i < B\) 的项,其一定是 \(k\) 的倍数。对于 \(b_i = B\) 的项,则 \(k \mid \sum [b_i = 0] a_i\) ,而 \(k \nmid a_i\) ,因此这样的项至少有两个,把它们加在一起即可让指数至少消掉 \(1\) 。由此可以归纳证明充分性。

接下来考虑从大到小对 \(b\) DP,设 \(f_{s, p}\) 表示集合 \(s\) 的和为 \(p\) 是否可行,则:

  • 向集合中加入 \(b_i = 0\) 的一个 \(a_i\)\(f_{s, p} \to f_{s \cup \{ i \}, p + a_i}\)
  • 把集合中的 \(b_i\) 全部增加 \(1\)\(f_{s, p} \to f_{s, \frac{p}{k}}\) ,其中 \(k \mid p\) ,转移顺序为完全背包。

不难用 bitset 优化到 \(O(\frac{n 2^n \sum a_i}{\omega})\)

#include <bits/stdc++.h>
using namespace std;
const int N = 16, S = 2e3 + 7;

bitset<S> f[1 << N];

int a[N], b[N];

int n, k, m;

signed main() {
    scanf("%d%d", &n, &k);

    for (int i = 0; i < n; ++i)
        scanf("%d", a + i), m += a[i];

    f[0].set(0);

    for (int s = 1; s < (1 << n); ++s) {
        for (int i = 0; i < n; ++i)
            if (s >> i & 1)
                f[s] |= f[s ^ (1 << i)] << a[i];

        for (int i = m / k; i; --i)
            f[s][i] = f[s][i] | f[s][i * k];
    }

    if (!f[(1 << n) - 1][1])
        return puts("NO"), 0;

    puts("YES");

    for (int s = (1 << n) - 1, p = 1; s;) {
        if (p * k <= m && f[s][p * k]) {
            for (int i = 0; i < n; ++i)
                if (s >> i & 1)
                    ++b[i];

            p *= k;
        } else {
            for (int i = 0; i < n; ++i)
                if ((s >> i & 1) && a[i] <= p && f[s ^ (1 << i)][p - a[i]]) {
                    s ^= 1 << i, p -= a[i];
                    break;
                }
        }
    }

    priority_queue<pair<int, int> > q;

    for (int i = 0; i < n; ++i)
        q.emplace(b[i], a[i]);

    while (q.size() >= 2) {
        auto x = q.top();
        q.pop();
        auto y = q.top();
        q.pop();
        printf("%d %d\n", x.second, y.second);
        int val = x.second + y.second, cnt = x.first;

        while (!(val % k))
            --cnt, val /= k;

        q.emplace(cnt, val);
    }

    return 0;
}

连续段 DP

在一类序列计数的问题中,状态的转移可能与相邻的已插入元素的值紧密相关,只有知道其值才能求解。而如果此时在只考虑往序列两端插入的情况下,问题将变得容易解决的时候,就可以利用连续段 DP。

此类问题一半形如:求满足某些限制的 \(n\) 个元素的排列数量。

\(f_{i, j}\) 表示前 \(i\) 个元素,已经形成 \(j\) 个连续段的方案数(视题目限制不同可能会在此基础上记录更多信息,此处不考虑其余限制)。分讨转移:

  • 将被插入元素用于新建连续段:\(f_{i, j} \times (j + 1) \to f_{i + 1, j + 1}\)
  • 插入元素至已有连续段的两端:\(f_{i, j} \times 2j \to f_{i + 1, j}\)
    • 部分题目左、右侧插入的状态转移可能不同,需分类讨论。
  • 将被插入元素用于合并两连续段:\(f_{i, j} \times (j - 1) \to f_{i + 1, j - 1}\)

通过三种转移方式所得到的结果,一定与原序列成一一对应,因为不同的插入方式本质上决定着元素在原序列中的位置。

CF1515E Phoenix and Computers

\(n\) 盏灯,需要依次开启所有灯。每次可以手动开启一盏关闭的灯,任意时刻若第 \(i - 1\) 盏灯和第 \(i + 1\) 盏灯同时开启,则第 \(i\) 盏灯自动开启。求开灯方案数。

\(n \le 400\)

\(f_{i, j}\) 表示考虑了 \(i\) 盏灯,组成了 \(j\) 个亮灯连续段的方案数,分讨转移:

  • 作为新的连续段插入:\(f_{i, j} \times (j + 1) \to f_{i + 1, j + 1}\)
  • 在已有连续段的一端插入新元素:
    • 直接在端点紧邻的位置插入: \(f_{i, j} \times 2j \to f_{i + 1, j}\)
    • 在与端点相隔一格的位置插入: \(f_{i, j} \times 2j \to f_{i + 2, j}\)
  • 合并两连续段:
    • 在两连续段间插入两盏灯并打开其中一个:\(f_{i, j} \times 2(j - 1) \to f_{i + 2, j - 1}\)
    • 插入三盏灯并打开中间那个:\(f_{i, j} \times (j - 1) \to f_{i + 3, j - 1}\)

答案即为 \(f_{n, 1}\) ,时间复杂度 \(O(n^2)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 4e2 + 7;

int f[N][N];

int n, Mod;

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

signed main() {
    scanf("%d%d", &n, &Mod);
    f[1][1] = 1;

    for (int i = 1; i < n; ++i)
        for (int j = 1; j <= i; ++j) {
            f[i + 1][j + 1] = add(f[i + 1][j + 1], 1ll * f[i][j] * (j + 1) % Mod);

            f[i + 1][j] = add(f[i + 1][j], 2ll * f[i][j] * j % Mod);
            f[i + 2][j] = add(f[i + 2][j], 2ll * f[i][j] * j % Mod);

            if (j) {
                f[i + 2][j - 1] = add(f[i + 2][j - 1], 2ll * f[i][j] * (j - 1) % Mod);
                f[i + 3][j - 1] = add(f[i + 3][j - 1], 1ll * f[i][j] * (j - 1) % Mod);
            }
        }

    printf("%d", f[n][1]);
    return 0;
}

P7967 [COCI2021-2022#2] Magneti

给定 \(n\) 个磁铁和 \(l\) 个相邻空位,每个空位可放置一个磁铁。所有 \(n\) 个磁铁都必须被放置。每个磁铁可以吸引距离小于 \(r_i\) 的其它磁铁。求所有磁铁互不吸引的方案数。

\(n \le 50\)\(l \le 10^4\)

记覆盖区间表示磁铁经过放置最左端、最右端的磁铁的位置组成的区间。

先做一个转化:考虑对于一种磁铁的排列顺序 \(p\) ,求出其极小的覆盖区间长度 \(D(p)\) ,则方案数为 \(\binom{l - D(p) + n}{n}\)

考虑按 \(r\) 从小到大插入,这样限制只要考虑后插入的 \(r\) 。设 \(f_{i, j, k}\) 表示插入了前 \(i\) 个元素,形成 \(j\) 个连续段,目前覆盖区间长度为 \(k\) 的方案数。

  • 新建连续段:\(f_{i, j, k} \times (j + 1) \to f_{i + 1, j + 1, k + 1}\)
  • 插入已有连续段一端: \(f_{i, j, k} \times 2j \to f_{i + 1, j, k + r_{i + 1}}\) .
  • 合并两连续段:\(f_{i, j, k} \times (j - 1) \to f_{i + 1, j - 1, k + 2 r_{i + 1} - 1}\)

答案即为 \(\sum_{i = n}^l \binom{l - i + n}{n} f_{n, 1, i}\)

#include <bits/stdc++.h>
using namespace std;
const int Mod = 1e9 + 7;
const int N = 5e1 + 7, L = 1e4 + 7;

int f[N][N][L];
int fac[L], inv[L], invfac[L];
int r[N];

int n, l;

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

inline void prework(int n) {
    fac[0] = fac[1] = 1;
    inv[0] = inv[1] = 1;
    invfac[0] = invfac[1] = 1;
    
    for (int i = 2; i <= n; ++i) {
        fac[i] = 1ll * fac[i - 1] * i % Mod;
        inv[i] = 1ll * (Mod - Mod / i) * inv[Mod % i] % Mod;
        invfac[i] = 1ll * invfac[i - 1] * inv[i] % Mod;
    }
}

inline int C(int n, int m) {
    return m > n ? 0 : 1ll * fac[n] * invfac[m] % Mod * invfac[n - m] % Mod;
}

signed main() {
    scanf("%d%d", &n, &l);
    prework(l);

    for (int i = 1; i <= n; ++i)
        scanf("%d", r + i);

    sort(r + 1, r + n + 1);
    f[1][1][1] = 1;

    for (int i = 1; i < n; ++i)
        for (int j = 1; j <= i; ++j)
            for (int k = 1; k <= l; ++k)
                if (f[i][j][k]) {
                    if (k + 1 <= l)
                        f[i + 1][j + 1][k + 1] = add(f[i + 1][j + 1][k + 1], 
                            1ll * f[i][j][k] * (j + 1) % Mod);

                    if (k + r[i + 1] <= l)
                        f[i + 1][j][k + r[i + 1]] = add(f[i + 1][j][k + r[i + 1]], 
                            2ll * f[i][j][k] * j % Mod);

                    if (k + r[i + 1] * 2 - 1 <= l)
                        f[i + 1][j - 1][k + r[i + 1] * 2 - 1] = add(f[i + 1][j - 1][k + r[i + 1] * 2 - 1], 
                            1ll * f[i][j][k] * (j - 1) % Mod);
                }

    int ans = 0;

    for (int i = 0; i <= l; ++i)
        ans = add(ans, 1ll * C(l - i + n, n) * f[n][1][i] % Mod);

    printf("%d", ans);
    return 0;
}

P5999 [CEOI2016] kangaroo

求有多少首尾为 \(s, t\) 、长度为 \(n\) 的波浪形排列。

\(n \le 2000\)

\(f_{i, j}\) 表示 \(1 \sim i\) 分出 \(j\) 个连续段的方案数,注意本题不能在一个段的两端放,这是因为之后肯定会有更大的接在一侧,形成长度 \(\geq 3\) 的单调段,这是非法的。

特殊处理一下 \(i = s, t\) 的情况,此时可以接在开头/结尾,此时要么放在一个段的一段,要么新建一个连续段。

时间复杂度 \(O(n^2)\)

#include <bits/stdc++.h>
using namespace std;
const int Mod = 1e9 + 7;
const int N = 2e3 + 7;

int f[N][N];

int n, s, t;

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

signed main() {
    scanf("%d%d%d", &n, &s, &t);
    f[1][1] = 1;

    for (int i = 1; i < n; ++i)
        for (int j = 1; j <= i; ++j) {
            if (i + 1 != s && i + 1 != t) {
                f[i + 1][j - 1] = add(f[i + 1][j - 1], 1ll * f[i][j] * (j - 1) % Mod);
                f[i + 1][j + 1] = add(f[i + 1][j + 1], 1ll * f[i][j] * (j + 1 - (i + 1 > s) - (i + 1 > t)) % Mod);
            } else {
                f[i + 1][j] = add(f[i + 1][j], f[i][j]);
                f[i + 1][j + 1] = add(f[i + 1][j + 1], f[i][j]);
            }
        }

    printf("%d", f[n][1]);
    return 0;
}

P9197 [JOI Open 2016] 摩天大楼

给出 \(a_{1 \sim n}\) ,求有多少重排序列的方案使得相邻两数差的绝对值之和 \(\le L\)

\(n \le 100\)\(a_i, L \le 1000\)

注意到绝对值求和的形式,因此考虑将权值排序从而去掉绝对值。

把重排后的序列放到坐标系上,那么所要求的就是红线的长度和(纵坐标差)。

容易想到从上往下扫,然后动态计算目前的折线长度和。每次从大到小加入 \(a_i\) ,那么新增加的折线数量就是目前连续段数的两倍。

注意当最后一个数已经插入时,后面不能再算贡献(绿色紫色虚线部分),同理第一个数已经插入时前面不能再算贡献。

\(f_{i, j, k, 0/1, 0/1}\) 表示目前加入了 \(i\) 个数,已有 \(j\) 个连续段,答案为 \(k\) ,第一个数/最后一个数有没有被加进去,不难做到 \(O(n^2 L)\) 转移。

#include <bits/stdc++.h>
using namespace std;
const int Mod = 1e9 + 7;
const int N = 1e2 + 7, M = 1e3 + 7;

int a[N], f[N][N][M][2][2];

int n, m;

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    sort(a + 1, a + n + 1, greater<int>());
    f[1][1][0][0][0] = f[1][1][0][0][1] = f[1][1][0][1][0] = f[1][1][0][1][1] = 1;

    for (int i = 1; i < n; ++i)
        for (int j = 1; j <= i; ++j)
            for (int k = 0; k <= m; ++k)
                for (int x = 0; x <= 1; ++x)
                    for (int y = 0; y <= 1; ++y) {
                        int res = f[i][j][k][x][y], t = k + (j * 2 - x - y) * (a[i] - a[i + 1]);

                        if (!res || t > m)
                            continue;

                        //新建一个连续段
                        if (j > 1) // 在中间
                            f[i + 1][j + 1][t][x][y] = add(f[i + 1][j + 1][t][x][y], 1ll * (j - 1) * res % Mod);

                        if (!x) { // 在左边
                            f[i + 1][j + 1][t][0][y] = add(f[i + 1][j + 1][t][0][y], res);
                            f[i + 1][j + 1][t][1][y] = add(f[i + 1][j + 1][t][1][y], res);
                        }

                        if (!y) { // 在右边
                            f[i + 1][j + 1][t][x][0] = add(f[i + 1][j + 1][t][x][0], res);
                            f[i + 1][j + 1][t][x][1] = add(f[i + 1][j + 1][t][x][1], res);
                        }

                        // 插入已有连续段一端
                        if (j > 1) // 在中间
                            f[i + 1][j][t][x][y] = add(f[i + 1][j][t][x][y], 2ll * (j - 1) * res % Mod);

                        if (!x) { // 在左边
                            f[i + 1][j][t][0][y] = add(f[i + 1][j][t][0][y], res);
                            f[i + 1][j][t][1][y] = add(f[i + 1][j][t][1][y], res);
                        }

                        if (!y) { // 在右边
                            f[i + 1][j][t][x][0] = add(f[i + 1][j][t][x][0], res);
                            f[i + 1][j][t][x][1] = add(f[i + 1][j][t][x][1], res);
                        }

                        // 合并两个连续段
                        if (j > 1)
                            f[i + 1][j - 1][t][x][y] = add(f[i + 1][j - 1][t][x][y], 1ll * (j - 1) * res % Mod);
                    }
 
    int ans = 0;

    for (int i = 0; i <= m; ++i)
        ans = add(ans, f[n][1][i][1][1]);

    printf("%d", ans);
    return 0;
}

费用提前计算

未来费用仅与当前有关

这类题目通常具有以下特点:

  • 当前决策对未来行动的费用影响只与当前决策有关。
  • 对状态增加一维来记录决策对未来的影响造成的复杂度代价过高。
  • 对未来的代价是线性的关系,根据线性性可以直接累加。

考虑把这个代价看作决策本身的费用,将未来的代价提前计算出来,在决策的时候就计算上它将会带来的代价,并向后传递。

P2365 任务安排

\(n\) 个任务按顺序分批执行,每批任务开始需要一个固定的启动时间 \(S\) 。第 \(i\) 个任务花费的时间是 \(t_i\) ,每个任务的花费是它完成的时刻乘上它自身的费用系数 \(f_i\)。需要找到一个最佳的分批顺序使得总费用最小。

\(n \le 5000\)

\(T_i, F_i\) 为前缀和数组,设 \(g_{i, j}\) 表示前 \(i\) 个任务分为 \(j\) 组的最小花费,则:

\[g_{i, j} = \min_{0 \le k < i} \{ g_{k, j - 1} + (S \times j + T_i) \times (F_i - F_k) \} \]

注意到这个 DP 的状态设计之所以记录 \(j\) 这一维,是因为需要知道前面有多少次启动了机器,即分成了多少批任务。

但是我们并不关心启动了几次机器,只关心到底因为 \(S\) 造成了多少费用。一批任务启动的时间 \(S\) 会累加到后面每一个任务上,所以可以将对后面任务造成的影响,累加到当前的费用中。设 \(g_i\) 表示把前 \(i\) 个任务分成若干个组的最小花费,有转移方程:

\[g_i = \min_{0 \le j < i} \{ g_j + T_i \times (F_i - F_j) + S \times (F_n - F_j) \} \]

时间复杂度 \(O(n^2)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const int N = 5e3 + 7;

ll t[N], f[N], g[N];

int n, s;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

signed main() {
    n = read(), s = read();

    for (int i = 1; i <= n; ++i)
        t[i] = t[i - 1] + read(), f[i] = f[i - 1] + read();

    memset(g, inf, sizeof(g));
    g[0] = 0;

    for (int i = 1; i <= n; ++i)
        for (int j = 0; j < i; ++j)
            g[i] = min(g[i], g[j] + t[i] * (f[i] - f[j]) + s * (f[n] - f[j]));

    printf("%lld", g[n]);
    return 0;
}

P4870 [BalticOI 2009 Day1] 甲虫

数轴上有 \(n\) 滴水,每滴水最开始有 \(m\) 的大小。每个单位时间内可以移动一个单位长度,同时所有水减小 \(1\),求最多能喝多少水。

\(n \le 300\)

考虑到这个露水会变成 \(0\) 就不会减小了,所以考虑枚举必吃的露水的数量。这样最优解一定会被枚举到,而且因为取得是最大值,就算减成负数也不影响。

首先对露水排个序。因为已知露水总量,并且时间会对每一个还没有吃掉的露水造成负贡献。因为这个时间不好存储,于是考虑费用提前计算,设 \(f_{l, r, 0/1}\) 表示吃完了 \([l, r]\) 的露水,现在在左/右端点的最多喝水量。令 \(s = n - r + l\) ,则有

\[f_{l, r, 0} = \max(f_{l + 1, r, 0} - (a_{l + 1} - a_l) \times s, f_{l + 1, r, 1} - (a_r - a_l)\times s) + m \\ f_{l, r, 1} = \max(f_{l, r - 1, 1} - (a_r - a_{r - 1}) \times s, f_{l, r - 1, 0} - (a_r - a_l) \times s) + m \]

考虑解释这个方程,可以理解成为周围的露水都还在丢失,在这里减去贡献;相对认为当前这个露水没有损失水量,因为损失的水量我们已经丢去了。时间复杂度 \(O(n^3)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 3e2 + 7;

int f[N][N][2];
int a[N];

int n, m;

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);

    sort(a + 1, a + 1 + n);
    int ans = 0;

    for (int num = 1; num <= n; ++num) {
        memset(f, 0, sizeof(f));

        for (int i = 1; i <= n; ++i)
            ans = max(ans, f[i][i][0] = f[i][i][1] = m - abs(a[i]) * num);

        for (int len = 2; len <= num; ++len)
            for (int l = 1, r = len; r <= n; ++l, ++r) {
                f[l][r][0] = max(f[l + 1][r][0] + m - (a[l + 1] - a[l]) * (num - r + l),
                    f[l + 1][r][1] + m - (a[r] - a[l]) * (num - r + l));
                f[l][r][1] = max(f[l][r - 1][1] + m - (a[r] - a[r - 1]) * (num - r + l),
                    f[l][r - 1][0] + m - (a[r] - a[l]) * (num - r + l));
                ans = max(ans, max(f[l][r][0], f[l][r][1]));
            }
    }

    printf("%d", ans);
    return 0;
}

CF441E Valera and Number

给出一个数 \(x\)\(n\) 次操作,每次操作有 \(\frac{p}{100}\) 的概率令 \(x \leftarrow 2x\) ,有 \(1 - \frac{p}{100}\) 的概率令 \(x \leftarrow x + 1\) ,求 \(x\) 最终二进制下末尾 \(0\) 的个数的期望。

\(n \le 200\)

由于两个操作会互相影响,考虑倒着操作,这样在一个 \(\times 2\) 前面的 \(+1\) 对后面的影响就变成了 \(+2\) ,进而不会影响后面的最低位的。将末尾的若干个操作看作 \(j\)\(+1\) 后面跟着 \(k\)\(\times 2\)

  • 如果来了一个 \(+1\) 则将 \(j \gets j+1\)
  • 如果来了一个 \(\times 2\)
    • \(j\) 为偶数,则可以视作 \(\frac{j}{2}\)\(+1\),而后面 \(\times 2\) 多了一个。
    • \(j\) 为奇数,则这个最低的 \(1\) 就确定不会改变了,概率乘上 \(k\) 贡献给答案(费用提前计算)。

\(\times 2\)\(j\) 为偶数时可以把通过 \(\times 2\) 多获得的那一个 \(0\) 单独对答案产生 \(1\) 的贡献提前计算,直接乘上概率累加到答案中即可。最后只要加上 \(+1\) 带来的贡献即可,这样就不用存 \(k\) 了。

时间复杂度 \(O(n^2)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 2e2 + 7;

double f[N][N];

double p, ans;
int x, n;

signed main() {
    scanf("%d%d%lf", &x, &n, &p);
    p /= 100, f[0][0] = 1;

    for (int i = 0; i < n; ++i)
        for (int j = 0; j <= i; ++j) {
            if (~j & 1)
                f[i + 1][j / 2] += f[i][j] * p, ans += f[i][j] * p;

            f[i + 1][j + 1] += f[i][j] * (1 - p);
        }

    for (int i = 0; i <= n; ++i)
        ans += __builtin_ctz(x + i) * f[n][i];

    printf("%.7lf", ans);
    return 0;
}

[ARC126D] Pure Straight

给定一个长度为 \(n\) 序列 \(a\) ,有 \(\forall a_i \in [1, k]\) 。每次操作可以交换两个相邻的元素,求最少操作多少次可以使得 \(a\) 中存在一个区间 \([l, l + k - 1]\) ,满足 \(\forall i \in [l, l + k - 1], a_i = i - l + 1\)

\(n \le 200\)\(k \le 16\)

注意到 \(k\) 特别小,考虑状压 DP。设 \(f_{i, S}\) 表示考虑到前 \(i\) 个数,最终答案中已经排好了 \(S\) 二进制位上为 \(1\) 的数并连在了一起。

当决策一个新的数的时候,不妨假设新插入的这个数已经紧贴在先前排好的序列右边了,要插入的话直接按照规则插入,不需要计算从别的地方移动过来的费用。如果放进最终答案,费用就是把它移到相应位置,即以其为结尾的序列的逆序对的个数;如果不放进最终答案,那么代价就是让最后答案中的数跨过它。

但是每个数具体被跨过了几次并不好求,考虑费用提前计算,把被跨过的次数转换为跨过它的数字数量,将费用均摊在每个剩下选中的每个点上。那么就是所有比他小的数都要跨过它一次,或者所有比它大的数都要跨过它一次,两者贪心取 \(\min\) 即可。

#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 2e2 + 7, K = 17;

int a[N], f[1 << K | 1];

int n, k;

signed main() {
    scanf("%d%d", &n, &k);

    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i), --a[i];

    memset(f, inf, sizeof(f));
    f[0] = 0;
    int all = (1 << k) - 1;
    #define ppc __builtin_popcount

    for (int i = 1; i <= n; ++i)
        for (int j = all; ~j; --j) {
            if (f[j] == inf)
                continue;

            if (~j >> a[i] & 1)
                f[j | (1 << a[i])] = min(f[j | (1 << a[i])], f[j] + ppc(j >> a[i] << a[i]));

            f[j] += min(ppc(j), ppc(j ^ all));
        }

    #undef ppc
    printf("%d", f[all]);
    return 0;
}

未来费用与未来有关

这一类题目通常有以下特点

  • 未来的决策并不只于当前决策有关,还与未来本身状态相关。
  • 对未来的代价并非线性的关系,不能简单的累加。

通常解决方式为增加一维状态来记录对未来的预测,并计算可能出现的代价,从而在未来能够直接使用。

UVA10559 方块消除 Blocks

给定长度为 \(n\) 的方块序列,每个方块有一个颜色,每次消除一段颜色相同长度为 \(x\) 的方块,并获得 \(x^2\) 的分数,消除后剩下的方块会合并起来。求最大得分。

\(n \le 200\)

注意到这个得分函数是二次的关系,不能直接用现在的决策推未来的费用。不妨设 \(f_{l, r, k}\) 表示消掉了 \([l, r]\) 这个区间,且后面有 \(k\) 个位置和 \(r\) 位置合并在一起消掉了。

  • 如果 \(r\) 和后面 \(k\) 个块一起消掉,就有 \(f_{l, r, k} \gets f_{l, r - 1, 0} + (k + 1)^2\)

  • 如果在 \([l, r - 1]\) 内还有与 \(r\) 颜色相同的块 \(x\) ,那么可以先消除 \([x + 1, r - 1]\) 然后合并 \(x, j\) 再一起消除,就有 \(f_{l, r, k} \gets f_{l, x, k + 1} + f_{x + 1, r - 1, 0}\)

最后答案即为 \(f_{1, n, 0}\) ,时间复杂度 \(O(n^4)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 2e2 + 7;

int f[N][N][N];
int col[N], cnt[N];

int n, m;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

int dfs(int l, int r, int k) {
    if (~f[l][r][k])
        return f[l][r][k];

    if (l == r)
        return (cnt[l] + k) * (cnt[l] + k);

    f[l][r][k] = dfs(l, r - 1, 0) + (cnt[r] + k) * (cnt[r] + k);

    for (int i = l; i + 1 <= r - 1; ++i)
        if (col[i] == col[r])
            f[l][r][k] = max(f[l][r][k], dfs(l, i, cnt[r] + k) + dfs(i + 1, r - 1, 0));

    return f[l][r][k];
}

signed main() {
    int T = read();

    for (int task = 1; task <= T; ++task) {
        n = read(), m = 0;

        for (int i = 1, lst = -1; i <= n; ++i) {
            int x = read();

            if (x != lst)
                col[++m] = x, cnt[m] = 0;

            ++cnt[m], lst = x;
        }

        memset(f, -1, sizeof(f));
        printf("Case %d: %d\n", task, dfs(1, m, 0));
    }

    return 0;
}

P3354 [IOI2005] Riv 河流

给定一棵树,\(0\) 为根,点边均带权,可以选 \(k\) 个关键点(根节点本身为关键点且不算在 \(k\) 个内),记每个节点离它最近的关键点祖先为 \(f_i\)(可以为本身),最小化 \(\sum dis_{i \to f_i} \times val_i\)

\(n \le 100\)

\(f_{i, j, 0/1}\) 表示 \(i\) 的子树中选取了 \(j\) 个关键点,是否选取 \(i\) 的最小权值和。但是我们并不容易知道具体有多少个节点选择了 \(i\) 为祖先关键点,无法统计答案。

考虑决策选择的节点是谁,并把它记录在状态中。设 \(f_{u, i, j, 0/1}\) 表示 \(u\) 的子树中选取了 \(j\) 个关键点,其中 \(u\) 的祖先关键点为 \(i\) ,是否选取 \(u\) 的最小权值和。转移时在每个节点上先枚举它的祖先关键点,再枚举在它子树中选了几个关键点,对于每个不同数量的关键点的决策做一个类似于树上背包的东西。时间复杂度 \(O(n^2 k^2)\)

#include <bits/stdc++.h>
using namespace std;
const int N = 1e2 + 7;

struct Graph {
    vector<pair<int, int> > e[N];
    
    inline void insert(int u, int v, int w) {
        e[u].emplace_back(v, w);
    }
} G;

int f[N][N][N][2], a[N], sta[N], dis[N];

int n, m, top;

void dfs(int u) {
    sta[++top] = u;

    for (auto it : G.e[u]) {
        int v = it.first, w = it.second;
        dis[v] = dis[u] + w, dfs(v);

        for (int i = 1; i <= top; ++i)
            for (int j = m; ~j; --j) {
                f[u][sta[i]][j][0] += f[v][sta[i]][0][0], f[u][sta[i]][j][1] += f[v][u][0][0];

                for (int k = 1; k <= j; ++k) {
                    f[u][sta[i]][j][0] = min(f[u][sta[i]][j][0], f[u][sta[i]][j - k][0] + f[v][sta[i]][k][0]);
                    f[u][sta[i]][j][1] = min(f[u][sta[i]][j][1], f[u][sta[i]][j - k][1] + f[v][u][k][0]);
                }
            }
    }

    for (int i = 1; i <= top; ++i) {
        for (int j = m; j; --j)
            f[u][sta[i]][j][0] = min(f[u][sta[i]][j][0] + a[u] * (dis[u] - dis[sta[i]]), f[u][sta[i]][j - 1][1]);
            // 这里用 j - 1 的原因是之前算 f[u][i][j][1] 的时候假设 u 有选,实际上 u 并不算在 j 里面,故要用 j - 1

        f[u][sta[i]][0][0] += a[u] * (dis[u] - dis[sta[i]]);
    } // 将 f[u][i][0 / 1] 都并到 f[u][i][j][0] 上,因为之后不在乎 u 是否有选

    --top;
}

signed main() {
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; ++i) {
        int f, w;
        scanf("%d%d%d", a + i, &f, &w);
        G.insert(f, i, w);
    }

    dfs(0);
    printf("%d", f[0][0][m][0]);
    return 0;
}

整体 DP

整体 DP 通过类似数据结构维护序列的方式将 DP 状态中的一维压入数据结构,并通过少量单点或区间的操作达到对该维所有状态进行转移的目的,从而降低转移的复杂度。

对于一类树上连通块计数问题,的朴素维护方式是:对每个点维护一个 DP 数组,合并两个子树时就把每个下标的对应信息合并,添加一个点时需要对某个位置进行修改。对于一个连通块,在其深度最浅的位置统计信息。

注意到此类问题虽然状态数为 \(O(nm)\) ,但是大部分状态只是对子树信息的合并,而插入单点信息只有 \(O(n)\) 次。此时可以用线段树维护每个点的 DP 数组,于是维护信息转变为线段树合并与单点插入。

这样一来,线段树中插入的总点数是 \(O(n \log n)\) 级别的,而线段树合并的复杂度不高于其插入总复杂度。

P9400 「DBOI」Round 1 三班不一般

给出 \(n\)\([l_i, r_i]\) ,求有多少长度为 \(n\) 的序列满足 \(a_i \in [l_i, r_i]\) 且不存在长度为 \(a\) 的子序列满足值均大于 \(b\)

\(n \le 2 \times 10^5\)\(l_i, r_i, b \le 10^9\)

\(f_{i, j}\) 表示前 \(i\) 个位置,值大于 \(b\) 的后缀长度为 \(j\) ,记 \(c_{i, 0}\) 表示这一位能填 \(\le b\) 的数字数量,\(c_{i, 1}\) 表示这一位能填 \(>b\) 的数字数量。则:

\[f_{i, 0} = \sum_{j = 0}^{a - 1} f_{i - 1, j} \times c_{i, 0} \\ f_{i, j} = f_{i - 1, j - 1} \times c_{i, 1} \]

前者直接对整个数组求和即可,后者需要支持区间平移的操作,不难用平衡树优化 DP 做到 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef unsigned int uint;
using namespace std;
const int Mod = 998244353;
const int N = 2e5 + 7;

int n, a, b;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

namespace fhqTreap {
uint dat[N];
int lc[N], rc[N], siz[N], val[N], s[N], tag[N];

mt19937 myrand(time(0));
int root, tot;

inline int newnode(int k) {
    dat[++tot] = myrand(), val[tot] = s[tot] = k, tag[tot] = 1, siz[tot] = 1;
    return tot;
}

inline void pushup(int x) {
    siz[x] = 1, s[x] = val[x];

    if (lc[x])
        siz[x] += siz[lc[x]], s[x] = add(s[x], s[lc[x]]);

    if (rc[x])
        siz[x] += siz[rc[x]], s[x] = add(s[x], s[rc[x]]);
}

inline void spread(int x, int k) {
    val[x] = 1ll * val[x] * k % Mod;
    s[x] = 1ll * s[x] * k % Mod;
    tag[x] = 1ll * tag[x] * k % Mod;
}

inline void pushdown(int x) {
    if (tag[x] != 1) {
        if (lc[x])
            spread(lc[x], tag[x]);

        if (rc[x])
            spread(rc[x], tag[x]);

        tag[x] = 1;
    }
}

void split(int x, int k, int &a, int &b) {
    if (!x) {
        a = b = 0;
        return;
    }

    pushdown(x);

    if (siz[lc[x]] + 1 <= k)
        a = x, split(rc[x], k - siz[lc[x]] - 1, rc[a], b);
    else
        b = x, split(lc[x], k, a, lc[b]);

    pushup(x);
}

int merge(int a, int b) {
    if (!a || !b)
        return a | b;

    pushdown(a), pushdown(b);

    if (dat[a] > dat[b])
        return rc[a] = merge(rc[a], b), pushup(a), a;
    else
        return lc[b] = merge(a, lc[b]), pushup(b), b;
}

inline void prework() {
    root = newnode(1);

    for (int i = 1; i < a; ++i)
        root = merge(root, newnode(0));
}

inline void update(int k) {
    int X, Y;
    split(root, a - 1, X, Y);
    val[Y] = s[Y] = k;
    root = merge(Y, X);
}
} // namespace fhqTreap

signed main() {
    n = read(), a = read(), b = read();
    fhqTreap::prework();

    for (int i = 1; i <= n; ++i) {
        int l = read(), r = read(), c0 = (l <= b ? min(r, b) - l + 1 : 0), c1 = (r - l + 1) - c0,
            res = 1ll * fhqTreap::s[fhqTreap::root] * c0 % Mod;
        fhqTreap::spread(fhqTreap::root, c1), fhqTreap::update(res);
    }

    printf("%d", fhqTreap::s[fhqTreap::root]);
    return 0;
}

P8476 「GLR-R3」惊蛰

给定非负整数序列 \(\{a_n\}\),定义函数 \(F(x,y) = \begin{cases} x - y, & x \ge y \\ C, & x < y \end{cases}\) ,其中 \(C\) 是给定常数,求所有 \(b\) 序列中 \(\sum_{i = 1}^n F(b_i, a_i)\) 最小值。

\(n \le 10^6\)

首先发现当 \(b_i \in \{a_1, a_2, \cdots, a_n \}\) 时一定不劣。证明就是考虑一个最优解,把不满足的位置变小答案一定不增。

\(f_{i, j}\) 表示前 \(i\) 个位置,\(b_i = j\) 的代价和,则:

\[f_{i, j} = \min_{k \ge j} f_{i - 1, k} + F(j, a_i) \]

其中 \(F(x, y)\) 为代价,可以做一遍后缀 \(\min\) 后单次线性转移,暴力转移是 \(O(n^2)\) 的。

考虑维护 \(f\) 的后缀 \(\min\) ,发现代价函数 \(F(x, y)\) 是关于 \(x\) 的分段函数:

  • \(x < y\)\(f_{i, j} \gets f_{i - 1, j} + C\)
  • \(x \ge y\)\(f_{i, j} \gets f_{i - 1, j} + j - a_i\)

不难发现两部分分别具有单调性,所以操作结束后 \(f_{i, j}\)\(j = a_i\) 为界分为两段单调序列,线段树二分维护后缀 \(\min\) 即可,时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e6 + 7;

vector<int> vec;

int a[N];

int n, c;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

namespace SMT {
struct Node {
    ll mn, mx, cov, tag, cnt;
} nd[N << 2];

inline int ls(int x) {
    return x << 1;
}

inline int rs(int x) {
    return x << 1 | 1;
}

inline void pushup(int x) {
    nd[x].mn = min(nd[ls(x)].mn, nd[rs(x)].mn);
    nd[x].mx = max(nd[rs(x)].mx, nd[rs(x)].mx);
}

inline void spread(int x, int l, int r, ll cv, ll tg, ll ct) {
    if (~cv)
        nd[x] = (Node) {cv, cv, cv, 0, 0};
    
    if (tg)
        nd[x].mn += tg, nd[x].mx += tg, nd[x].tag += tg;
    
    if (ct)
        nd[x].mn += ct * vec[l], nd[x].mx += ct * vec[r], nd[x].cnt += ct;
}

inline void pushdown(int x, int l, int r) {
    int mid = (l + r) >> 1;
    spread(ls(x), l, mid, nd[x].cov, nd[x].tag, nd[x].cnt);
    spread(rs(x), mid + 1, r, nd[x].cov, nd[x].tag, nd[x].cnt);
    nd[x].cov = -1, nd[x].tag = nd[x].cnt = 0;
}

void build(int x, int l, int r) {
    nd[x].cov = -1;

    if (l == r)
        return;

    int mid = (l + r) >> 1;
    build(ls(x), l, mid), build(rs(x), mid + 1, r);
}

inline void update(int x, int nl, int nr, int l, int r, ll cv, ll tg, ll ct) {
    if (l <= nl && nr <= r) {
        spread(x, nl, nr, cv, tg, ct);
        return;
    }

    pushdown(x, nl, nr);
    int mid = (nl + nr) >> 1;

    if (l <= mid)
        update(ls(x), nl, mid, l, r, cv, tg, ct);
    
    if (r > mid)
        update(rs(x), mid + 1, nr, l, r, cv, tg, ct);

    pushup(x);
}

int search(int x, int nl, int nr, ll k) {
    if (nd[x].mx < k)
        return nr + 1;

    if (nl == nr)
        return nl;

    pushdown(x, nl, nr);
    int mid = (nl + nr) >> 1;
    return nd[ls(x)].mx >= k ? search(ls(x), nl, mid, k) : search(rs(x), mid + 1, nr, k);
}

ll query(int x, int nl, int nr, int l, int r) {
    if (l <= nl && nr <= r)
        return nd[x].mn;

    pushdown(x, nl, nr);
    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return query(ls(x), nl, mid, l, r);
    else if (l > mid)
        return query(rs(x), mid + 1, nr, l, r);
    else
        return min(query(ls(x), nl, mid, l, r), query(rs(x), mid + 1, nr, l, r));
}
} // namespace SMT

signed main() {
    n = read(), c = read();

    for (int i = 1; i <= n; ++i)
        vec.emplace_back(a[i] = read());

    sort(vec.begin(), vec.end());
    vec.erase(unique(vec.begin(), vec.end()), vec.end());
    int m = vec.size() - 1;
    SMT::build(1, 0, m);

    for (int i = 1; i <= n; ++i) {
        int x = lower_bound(vec.begin(), vec.end(), a[i]) - vec.begin();

        if (x)
            SMT::update(1, 0, m, 0, x - 1, -1, c, 0);

        SMT::update(1, 0, m, x, m, -1, -a[i], 0), SMT::update(1, 0, m, x, m, -1, 0, 1);
        ll rmn = SMT::query(1, 0, m, x, m);
        int cur = SMT::search(1, 0, m, rmn);

        if (cur < x)
            SMT::update(1, 0, m, cur, x - 1, rmn, 0, 0);
    }

    printf("%lld", SMT::nd[1].mn);
    return 0;
}

P6773 [NOI2020] 命运

给定一棵树和 \(m\) 条限制,需要给每条边赋上一个 \(0/1\) 的权值,对于每个限制 \((u_i, v_i)\) (满足 \(u_i\)\(v_i\) 的祖先),满足 \(u_i \to v_i\) 的路径上至少有一条 \(1\) 边。求方案数。

\(n, m \le 5 \times 10^5\)

考虑对于 \(v\) 的所有限制,发现只要保留深度最深的 \(u\) 即可。

\(f_{u, i}\) 表示考虑 \(u\) 子树内赋值边的方案,限制下端在子树内且未被满足的限制中上端最深深度为 \(i\) 的方案数。讨论 \((u, v)\) 边的权值可以得到转移:

\[f_{u, i} \gets \sum_{j = 0}^{dep_u} f_{u, i} f_{v, j} + \sum_{j = 0}^i f_{u, i} f_{v, j} + \sum_{j = 0}^{i - 1} f_{u, j} f_{v, i} \]

\(g_{u, i} = \sum_{j = 0}^i f_{u, j}\) ,则:

\[f_{u, i} \gets f_{u, i} (g_{v, dep_u} + g_{v, i}) + g_{u, i - 1} f_{v, i} \]

线段树合并做整体 DP 即可做到 \(O(n \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int Mod = 998244353;
const int N = 5e5 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

int dep[N], mxd[N];

int n, m;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

void dfs1(int u, int f) {
    dep[u] = dep[f] + 1;

    for (int v : G.e[u])
        if (v != f)
            dfs1(v, u);
}

namespace SMT {
const int S = 3e7 + 7;

int lc[S], rc[S], s[S], tag[S];
int rt[N];

int tot;

inline void pushup(int x) {
    s[x] = add(s[lc[x]], s[rc[x]]);
}

inline void spread(int x, int k) {
    s[x] = 1ll * s[x] * k % Mod, tag[x] = 1ll * tag[x] * k % Mod;
}

inline void pushdown(int x) {
    if (tag[x] != 1) {
        if (lc[x])
            spread(lc[x], tag[x]);

        if (rc[x])
            spread(rc[x], tag[x]);

        tag[x] = 1;
    }
}

void update(int &x, int nl, int nr, int pos, int k) {
    if (!x)
        tag[x = ++tot] = 1;

    if (nl == nr) {
        s[x] = k;
        return;
    }

    int mid = (nl + nr) >> 1;

    if (pos <= mid)
        update(lc[x], nl, mid, pos, k);
    else
        update(rc[x], mid + 1, nr, pos, k);

    pushup(x);
}

int query(int x, int nl, int nr, int l, int r) {
    if (!x)
        return 0;

    if (l <= nl && nr <= r)
        return s[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (r <= mid)
        return query(lc[x], nl, mid, l, r);
    else if (l > mid)
        return query(rc[x], mid + 1, nr, l, r);
    else
        return add(query(lc[x], nl, mid, l, r), query(rc[x], mid + 1, nr, l, r));
}

int merge(int a, int b, int l, int r, int &resu, int &resv) {
    if (!a && !b)
        return 0;
    else if (!a) {
        resv = add(resv, s[b]), spread(b, resu);
        return b;
    } else if (!b) {
        resu = add(resu, s[a]), spread(a, resv);
        return a;
    }

    if (l == r) {
        int su = s[a], sv = s[b];
        resv = add(resv, sv);
        s[a] = add(1ll * s[a] * resv % Mod, 1ll * s[b] * resu % Mod);
        resu = add(resu, su);
        return a;
    }

    pushdown(a), pushdown(b);
    int mid = (l + r) >> 1;
    lc[a] = merge(lc[a], lc[b], l, mid, resu, resv);
    rc[a] = merge(rc[a], rc[b], mid + 1, r, resu, resv);
    return pushup(a), a;
}
} // namespace SMT

void dfs2(int u, int f) {
    SMT::update(SMT::rt[u], 0, n, mxd[u], 1);

    for (int v : G.e[u]) {
        if (v == f)
            continue;

        dfs2(v, u);
        int resu = 0, resv = SMT::query(SMT::rt[v], 0, n, 0, dep[u]);
        SMT::rt[u] = SMT::merge(SMT::rt[u], SMT::rt[v], 0, n, resu, resv);
    }
}

signed main() {
    n = read();

    for (int i = 1; i < n; ++i) {
        int u = read(), v = read();
        G.insert(u, v), G.insert(v, u);
    }

    dfs1(1, 0), m = read();

    for (int i = 1; i <= m; ++i) {
        int u = read(), v = read();
        mxd[v] = max(mxd[v], dep[u]);
    }

    dfs2(1, 0);
    printf("%d", SMT::query(SMT::rt[1], 0, n, 0, 0));
    return 0;
}

P5298 [PKUWC2018] Minimax

有一棵 \(n\) 个点的有根树,根是 \(1\) ,且每个结点最多有两个子结点。

定义结点 \(x\) 的权值为:

  • \(x\) 没有子结点,那么它的权值会在输入里给出,保证这类点中每个结点的权值互不相同。

  • \(x\) 有子结点,那么它的权值有 \(p_x\) 的概率是它的子结点的权值的最大值,有 \(1-p_x\) 的概率是它的子结点的权值的最小值。

假设 \(1\) 号结点的权值有 \(m\) 种可能性,权值第 \(i\) 小的可能性的权值是 \(V_i\),它的概率为 \(D_i\),求 \(\sum_{i = 1}^m i \times V_i \times D_i^2\)

\(n \le 3 \times 10^5\)

\(f_{u, i}\) 表示 \(u\) 权值为 \(i\) 的概率,\(m\) 表示叶子数量,则:

\[f_{u, i} = f_{lc, i} \times (p_u \times \sum_{j = 1}^{i - 1} f_{rc, j} + (1 - p_u) \times \sum_{j = i + 1}^m f_{rc, j}) + f_{rc, i} \times (p_u \times \sum_{j = 1}^{i - 1} f_{lc, j} + (1 - p_u) \times \sum_{j = i + 1}^m f_{lc, j}) \]

线段树合并维护整体 DP 即可,合并时需要记录前后缀的贡献,时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int Mod = 998244353;
const int N = 3e5 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

int fa[N], a[N], ans[N];

int n, m;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

inline int mi(int a, int b) {
    int res = 1;
    
    for (; b; b >>= 1, a = 1ll * a * a % Mod)
        if (b & 1)
            res = 1ll * res * a % Mod;
    
    return res;
}

namespace SMT {
const int S = 3e7 + 7;

int lc[S], rc[S], s[S], tag[S];
int rt[N];

int tot;

inline void pushup(int x) {
    s[x] = add(s[lc[x]], s[rc[x]]);
}

inline void spread(int x, int k) {
    s[x] = 1ll * s[x] * k % Mod, tag[x] = 1ll * tag[x] * k % Mod;
}

inline void pushdown(int x) {
    if (tag[x] != 1) {
        if (lc[x])
            spread(lc[x], tag[x]);

        if (rc[x])
            spread(rc[x], tag[x]);

        tag[x] = 1;
    }
}

void update(int &x, int nl, int nr, int pos, int k) {
    if (!x)
        tag[x = ++tot] = 1;

    if (nl == nr) {
        s[x] = k;
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (pos <= mid)
        update(lc[x], nl, mid, pos, k);
    else
        update(rc[x], mid + 1, nr, pos, k);

    pushup(x);
}

int merge(int a, int b, int l, int r, int resa, int resb, int p) {
    if (!a && !b)
        return 0;
    else if (!a)
        return spread(b, resa), b;
    else if (!b)
        return spread(a, resb), a;

    pushdown(a), pushdown(b);
    int mid = (l + r) >> 1, lca = s[lc[a]], rca = s[rc[a]], lcb = s[lc[b]], rcb = s[rc[b]];
    lc[a] = merge(lc[a], lc[b], l, mid, add(resa, 1ll * dec(1, p) * rca % Mod), 
        add(resb, 1ll * dec(1, p) * rcb % Mod), p);
    rc[a] = merge(rc[a], rc[b], mid + 1, r, add(resa, 1ll * p * lca % Mod), 
        add(resb, 1ll * p * lcb % Mod), p);
    return pushup(a), a;
}

void dfs(int x, int l, int r) {
    if (!x)
        return;

    if (l == r) {
        ans[l] = s[x];
        return;
    }

    pushdown(x);
    int mid = (l + r) >> 1;
    dfs(lc[x], l, mid), dfs(rc[x], mid + 1, r);
}
} // namespace SMT

void dfs(int u) {
    if (G.e[u].empty())
        SMT::update(SMT::rt[u], 1, m, a[u], 1);
    else if (G.e[u].size() == 1)
        dfs(G.e[u][0]), SMT::rt[u] = SMT::rt[G.e[u][0]];
    else {
        dfs(G.e[u][0]), dfs(G.e[u][1]);
        SMT::rt[u] = SMT::merge(SMT::rt[G.e[u][0]], SMT::rt[G.e[u][1]], 1, m, 0, 0, a[u]);
    }
}

signed main() {
    n = read();

    for (int i = 1; i <= n; ++i) {
        fa[i] = read();

        if (fa[i])
            G.insert(fa[i], i);
    }

    for (int i = 1; i <= n; ++i)
        a[i] = read();

    vector<int> vec;

    for (int i = 1; i <= n; ++i) {
        if (G.e[i].empty())
            vec.emplace_back(a[i]);
        else
            a[i] = 1ll * a[i] * mi(1e4, Mod - 2) % Mod;
    }

    sort(vec.begin(), vec.end()), m = vec.size();

    for (int i = 1; i <= n; ++i)
        if (G.e[i].empty())
            a[i] = lower_bound(vec.begin(), vec.end(), a[i]) - vec.begin() + 1;

    dfs(1), SMT::dfs(SMT::rt[1], 1, m);
    int answer = 0;

    for (int i = 1; i <= m; ++i)
        answer = add(answer, 1ll * i * vec[i - 1] % Mod * ans[i] % Mod * ans[i] % Mod);

    printf("%d", answer);
    return 0;
}

P11149 [THUWC 2018] 城市规划

给出一棵树,每个点有颜色,求有多少树上连通块包含不超过两种颜色。

\(n \le 5 \times 10^5\)

\(f_{u, i}\) 表示 \(u\) 子树中选一个两种颜色分别为 \(col_u\)\(i\) 的方案数,考虑 \(u\) 的一个儿子 \(v\) 的贡献:

  • \(col_u = col_v\)

    \[f_{u, i} \gets \begin{cases} f_{u, i} \times (f _{v, i} + f_{v, col_v} + 1) + f_{u, col_u} \times f_{v, i} & (i \ne col_u) \\ f_{u, i} \times (f_{v, i} + 1) & (i = col_u) \end{cases} \]

  • \(col_u \ne col_v\)

    \[f_{u, i} \gets \begin{cases} f_{u, i} & (i \ne col_v) \\ f_{u, i} \times (f_{v, col_u} + f_{v, i} + 1) + f_{u, col_u} \times (f_{v, col_u} + f_{v, i}) & (i = col_v) \end{cases} \]

观察到不同的颜色是彼此独立的,可以用线段树来维护一个点上的所有 DP 值,第一类转移用线段树合并,第二类转移用单点插入即可。

时间复杂度 \(O(n \log n)\)

#include <bits/stdc++.h>
using namespace std;
const int Mod = 998244353;
const int N = 5e5 + 7;

struct Graph {
    vector<int> e[N];
    
    inline void insert(int u, int v) {
        e[u].emplace_back(v);
    }
} G;

int col[N];

int n, ans;

template <class T = int>
inline T read() {
    char c = getchar();
    bool sign = (c == '-');
    
    while (c < '0' || c > '9')
        c = getchar(), sign |= (c == '-');
    
    T x = 0;
    
    while ('0' <= c && c <= '9')
        x = (x << 1) + (x << 3) + (c & 15), c = getchar();
    
    return sign ? (~x + 1) : x;
}

inline int add(int x, int y) {
    x += y;
    
    if (x >= Mod)
        x -= Mod;
    
    return x;
}

inline int dec(int x, int y) {
    x -= y;
    
    if (x < 0)
        x += Mod;
    
    return x;
}

inline int mi(int a, int b) {
    int res = 1;
    
    for (; b; b >>= 1, a = 1ll * a * a % Mod)
        if (b & 1)
            res = 1ll * res * a % Mod;
    
    return res;
}

namespace SMT {
const int S = 3e7 + 7;

int lc[S], rc[S], s[S], tag[S];
int rt[N];

int tot;

inline void pushup(int x) {
    s[x] = 0;

    if (lc[x])
        s[x] = add(s[x], s[lc[x]]);

    if (rc[x])
        s[x] = add(s[x], s[rc[x]]);
}

inline void spread(int x, int k) {
    s[x] = 1ll * s[x] * k % Mod, tag[x] = 1ll * tag[x] * k % Mod;
}

inline void pushdown(int x) {
    if (tag[x] != 1) {
        if (lc[x])
            spread(lc[x], tag[x]);

        if (rc[x])
            spread(rc[x], tag[x]);

        tag[x] = 1;
    }
}

void build(int &x, int l, int r) {
    tag[x = ++tot] = 1;

    if (l == r)
        return;

    int mid = (l + r) >> 1;
    build(lc[x], l, mid), build(rc[x], mid + 1, r);
}

void update(int &x, int nl, int nr, int pos, int k) {
    if (!x)
        tag[x = ++tot] = 1;

    if (nl == nr) {
        s[x] = k;
        return;
    }

    pushdown(x);
    int mid = (nl + nr) >> 1;

    if (pos <= mid)
        update(lc[x], nl, mid, pos, k);
    else
        update(rc[x], mid + 1, nr, pos, k);

    pushup(x);
}

int merge(int a, int b, int l, int r, int fu, int fv) {
    if (!a && !b)
        return 0;
    else if (!a)
        return spread(b, fu), b;
    else if (!b)
        return spread(a, add(fv, 1)), a;

    if (l == r) {
        s[a] = add(1ll * s[a] * add(add(s[b], fv), 1) % Mod, 1ll * s[b] * fu % Mod);
        return a;
    }

    pushdown(a), pushdown(b);
    int mid = (l + r) >> 1;
    lc[a] = merge(lc[a], lc[b], l, mid, fu, fv);
    rc[a] = merge(rc[a], rc[b], mid + 1, r, fu, fv);
    return pushup(a), a;
}

int query(int x, int nl, int nr, int pos) {
    if (!x)
        return 0;

    if (nl == nr)
        return s[x];

    pushdown(x);
    int mid = (nl + nr) >> 1;
    return pos <= mid ? query(lc[x], nl, mid, pos) : query(rc[x], mid + 1, nr, pos);
}
} // namespace SMT

using SMT::rt;

void dfs(int u, int fa) {
    SMT::update(rt[u], 1, n, col[u], 1);

    for (int v : G.e[u]) {
        if (v == fa)
            continue;

        dfs(v, u);

        if (col[v] == col[u]) {
            int fu = SMT::query(rt[u], 1, n, col[u]), fv = SMT::query(rt[v], 1, n, col[v]);
            rt[u] = SMT::merge(rt[u], rt[v], 1, n, fu, fv);
            SMT::update(rt[u], 1, n, col[u], 1ll * fu * add(fv, 1) % Mod);
        } else {
            int fuu = SMT::query(rt[u], 1, n, col[u]), fuv = SMT::query(rt[u], 1, n, col[v]),
                fvu = SMT::query(rt[v], 1, n, col[u]), fvv = SMT::query(rt[v], 1, n, col[v]);
            SMT::update(rt[u], 1, n, col[v], 
                add(1ll * fuv * add(add(fvu, fvv), 1) % Mod, 1ll * fuu * add(fvu, fvv) % Mod));
        }
    }

    ans = add(ans, SMT::s[rt[u]]);

}

signed main() {
    n = read();

    for (int i = 1; i <= n; ++i)
        col[i] = read();

    for (int i = 1; i < n; ++i) {
        int u = read(), v = read();
        G.insert(u, v), G.insert(v, u);
    }

    dfs(1, 0);
    printf("%d\n", ans);
    return 0;
}

CF809D Hitchhiking in the Baltic States

给出 \(n\) 个区间 \([l_i, r_i]\) ,求一个序列 \(a_{1 \sim n}\) 满足 \(a_i \in [l_i, r_i]\) ,并最大化最长严格上升子序列。

\(n \le 3 \times 10^5\)

考虑按顺序往最长严格上升子序列末尾加数,设 \(f_i\) 表示长度为 \(i\) 的最长严格上升子序列的最后一位的最小值,加入 \(a_i \in [l_i, r_i]\) 时分类讨论:

  • \(f_{j - 1} < l_i - 1\) :令 \(f_j = \min(f_j, l)\)
    • 事实上只有最末端的 \(f_j\) 会被如此更新,因为 \(f\) 具有单调性。
  • \(l_i - 1 \le f_{j - 1} \le r - 1\) :令 \(f_j = f_{j - 1} + 1\)
    • 不用取 \(\min\) 是因为 \(f_j \ge f_{j - 1} + 1\)
  • \(f_{j - 1} > r - 1\) :无法转移。

以上转移在平衡树上可以表示为:

  • \(f_j \in [l_i - 1, r_i - 1]\) 的部分整体 \(+1\)
  • 插入 \(l\) ,因而使得 \(f_j \in [l_i - 1, r_i - 1]\) 的部分整体右移一位。
  • 删除第一个 \(\ge r\)\(f_j\) ,保证后面的 DP 值不变。

不难用 fhq-Treap 做到 \(O(n \log n)\) ,答案即为最后平衡树的大小。

#include <bits/stdc++.h>
typedef unsigned int uint;
using namespace std;
const int N = 3e5 + 7;

mt19937 myrand(time(0));
int n;

namespace fhqTreap {
uint dat[N];
int lc[N], rc[N], val[N], siz[N], tag[N];

int root, tot;

inline int newnode(int k) {
    dat[++tot] = myrand(), val[tot] = k, siz[tot] = 1;
    return tot;
}

inline void pushup(int x) {
    siz[x] = siz[lc[x]] + siz[rc[x]] + 1;
}

inline void spread(int x, int k) {
    val[x] += k, tag[x] += k;
}

inline void pushdown(int x) {
    if (tag[x]) {
        if (lc[x])
            spread(lc[x], tag[x]);

        if (rc[x])
            spread(rc[x], tag[x]);

        tag[x] = 0;
    }
}

void split_val(int x, int k, int &a, int &b) {
    if (!x) {
        a = b = 0;
        return;
    }

    pushdown(x);

    if (val[x] <= k)
        a = x, split_val(rc[x], k, rc[a], b);
    else
        b = x, split_val(lc[x], k, a, lc[b]);

    pushup(x);
}

void split_siz(int x, int k, int &a, int &b) {
    if (!x) {
        a = b = 0;
        return;
    }

    pushdown(x);

    if (k >= siz[lc[x]] + 1)
        a = x, split_siz(rc[x], k - siz[lc[x]] - 1, rc[a], b);
    else
        b = x, split_siz(lc[x], k, a, lc[b]);

    pushup(x);
}

int merge(int a, int b) {
    if (!a || !b)
        return a | b;

    pushdown(a), pushdown(b);

    if (dat[a] > dat[b])
        return rc[a] = merge(rc[a], b), pushup(a), a;
    else
        return lc[b] = merge(a, lc[b]), pushup(b), b;
}

inline void update(int l, int r) {
    int a, b, c, d;
    split_val(root, r - 1, a, c), split_val(a, l - 2, a, b);
    spread(b, 1), split_siz(c, 1, c, d);
    root = merge(merge(a, newnode(l)), merge(b, d));
}
} // namespace fhqTreap

signed main() {
    scanf("%d", &n);

    for (int i = 1; i <= n; ++i) {
        int l, r;
        scanf("%d%d", &l, &r);
        fhqTreap::update(l, r);
    }

    printf("%d", fhqTreap::siz[fhqTreap::root]);
    return 0;
}
posted @ 2025-02-14 21:58  wshcl  阅读(5)  评论(0)    收藏  举报