学习笔记:拉格朗日插值法

多项式插值,即对已知的 \(n + 1\) 个点 \((x_1,y_1),(x_2,y_2),\dots,(x_{n+1},y_{n+1})\),求 \(n\) 次多项式 \(f(x)\),满足对于任意 \(1\le i\le n+1\)\(f(x_i)=y_i\)

拉格朗日插值法

设点 \((x_i,y_i)\)\(x\) 轴上的投影点为 \(P(x_i,0)\)

考虑构造 \(n+1\) 个函数 \(f_1(x),f_2(x),\dots,f_{n+1}(x)\),使得对于第 \(i\) 个函数 \(f_i(x) = \begin{cases}y_i & x=x_i \\ 0 & x \neq x_i\end{cases}\),即图像经过 \((x_i,y_i)\) 和其他不为 \(i\) 的点的投影。

那么 \(f(x)=\sum\limits_{i=1}^{n+1} f_i(x)\)

\(f_i(x)=a\prod\limits_{j\neq i} (x-x_j)\),将 \((x_i,y_i)\) 代入得到 \(a=\frac{y_i}{\prod\limits_{j\neq i} (x_i-x_j)}\)

于是 \(f_i(x)=y_i \prod\limits_{j\neq i}\frac{(x-x_j)}{(x_i-x_j)}\)

那么拉格朗日插值的形式为 \(f(x)=\sum\limits_{i=1}^{n+1} y_i \prod\limits_{j\neq i}\frac{(x-x_j)}{(x_i-x_j)}\)

LGP4781 【模板】拉格朗日插值

朴素的拉插复杂度为 \(O(n^2)\)

Code
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;

const int N = 2e3 + 5, MOD = 998244353;

int n, k, x[N], y[N];

int fpow(int a, int b) {
    int ans = 1;
    for (; b; b >>= 1) {
        if (b & 1) ans = (LL)ans * a % MOD;
        a = (LL)a * a % MOD;
    }
    return ans;
}

int main() {
    cin >> n >> k;
    for (int i = 1; i <= n; i++) {
        cin >> x[i] >> y[i];
    }
    int ans = 0;
    for (int i = 1; i <= n; i++) {
        int res1 = 1, res2 = 1;
        for (int j = 1; j <= n; j++) {
            if (j == i) continue;
            res1 = (LL)res1 * (k - x[j]) % MOD;
            res2 = (LL)res2 * (x[i] - x[j]) % MOD;
        }
        ans = (ans + (LL)y[i] * res1 % MOD * fpow(res2, MOD - 2) % MOD) % MOD;
    }
    cout << (ans + MOD) % MOD << '\n';
    return 0;
}

给定的横坐标为连续整数时的拉插

\(f(x)=\sum\limits_{i=1}^{n+1} y_i \prod\limits_{j\neq i}\frac{(x-x_j)}{(x_i-x_j)} = \sum\limits_{i=1}^{n+1} y_i \prod\limits_{j\neq i}\frac{(x-x_j)}{(i-j)}\)

分子部分的连乘可以表示为 \(\frac{\prod\limits_{j=1}^{n+1} x-x_j}{x-x_i}\),分母部分的连乘可以表示为 \((-1)^{n + 1 - i}(i-1)!(n+1-i)!\)

于是 \(f(x) = \sum\limits_{i=1}^{n+1} (-1)^{n + 1 - i}y_i \frac{\prod\limits_{j=1}^{n+1} x-x_j}{(x-x_i)(i-1)!(n+1-i)!}\)

预处理 \(x-x_i\) 的前后缀积,阶乘逆元,即可做到 \(O(n)\)

CF622F The Sum of the k-th Powers

Solution

先证明 \(\sum\limits_{i=1}^{n} i^k\) 是一个关于 \(n\)\(k+1\) 次多项式。

考虑数学归纳法,当 \(k=0\) 时,显然是一个一次多项式,

假设对于 \(x<k\),都有 \(\sum\limits_{i=1}^{n} i^x\) 是一个关于 \(n\)\(x+1\) 次多项式,下证 \(\sum\limits_{i=1}^{n} i^k\) 是一个关于 \(n\)\(k+1\) 次多项式。

考虑二项式定理,\((i+1)^{k+1}-i^{k+1} = \sum\limits_{x=0}^{k} \binom{k+1}{x}i^x\)

对于左边,\(\sum\limits_{i=1}^n (i+1)^{k+1}-i^{k+1} = (n+1)^{k+1}-1\)

对于右边,\(\sum\limits_{i=1}^n \sum\limits_{x=0}^{k} \binom{k+1}{x}i^x = \sum\limits_{x=0}^{k}\binom{k+1}{x}\sum\limits_{i=1}^{n}i^x\)

\(g_x(n) = \sum\limits_{i=1}^{n} i^x\),则 \((n+1)^{k+1}-1 = \sum\limits_{x=0}^{k}\binom{k+1}{x}g_x(n)\)

把右边第 \(k\) 项提出来,\(g_k(n) = \frac{(n+1)^{k+1}-1 - \sum\limits_{x=0}^{k-1}\binom{k+1}{x}g_x(n)}{k+1}\)

那么右边的最高次项显然是 \((n+1)^{k+1}\),因此得证。

那么我们知道,确定一个 \(k+1\) 次多项式需要 \(k+2\) 个点,如果随便取点再用普通拉插 \(O(k^2)\) 肯定过不了,于是我们取 \(x = 1,2,\dots,k+2\) 这几个点横坐标连续的点即可,纵坐标就是 \(y = \sum\limits_{i=1}^{x} i^{k}\)

\(f(x) = x^{k}\) 是一个完全积性函数,直接线性筛即可做到 \(O(k)\)

这样一来总复杂度就是 \(O(k)\) 了。

Code
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;

const int N = 1e6 + 5, MOD = 1e9 + 7;

int n, k, prime[N], vis[N], cnt, p[N], y[N], pre[N], suf[N], fact[N], inv[N];

int fpow(int a, int b) {
    int ans = 1;
    for (; b; b >>= 1) {
        if (b & 1) ans = (LL)ans * a % MOD;
        a = (LL)a * a % MOD;
    }
    return ans;
}

int main() {
    cin >> n >> k;
    p[1] = 1;
    for (int i = 2; i <= k + 2; i++) {
        if (!vis[i]) {
            prime[++cnt] = i;
            vis[i] = i;
            p[i] = fpow(i, k);
        }
        for (int j = 1; j <= cnt && prime[j] <= (k + 2) / i; j++) {
            vis[prime[j] * i] = prime[j];
            p[prime[j] * i] = (LL)p[i] * p[prime[j]] % MOD;
            if (i % prime[j] == 0) break;
        }
    }
    pre[0] = fact[0] = 1;
    for (int i = 1; i <= k + 2; i++) {
        y[i] = (y[i - 1] + p[i]) % MOD;
        pre[i] = (LL)pre[i - 1] * (n - i) % MOD;
        fact[i] = (LL)fact[i - 1] * i % MOD;
    }
    inv[k + 2] = fpow(fact[k + 2], MOD - 2);
    suf[k + 3] = 1;
    for (int i = k + 2; i >= 1; i--) suf[i] = (LL)suf[i + 1] * (n - i) % MOD;
    for (int i = k + 1; i >= 0; i--) inv[i] = (LL)inv[i + 1] * (i + 1) % MOD;
    int ans = 0;
    for (int i = 1; i <= k + 2; i++) {
        int op;
        if ((k + 2 - i) & 1) op = -1;
        else op = 1;
        ans = (ans + (LL)op * y[i] * inv[i - 1] % MOD * inv[k + 2 - i] % MOD * pre[i - 1] % MOD * suf[i + 1] % MOD) % MOD;
    }
    cout << (ans + MOD) % MOD << '\n';
    return 0;
}

拉插求系数

观察拉插基本形式的式子,\(\sum\limits_{i=1}^{n} y_i \prod\limits_{j\neq i}\frac{(x-x_j)}{(x_i-x_j)}\),(假设题目给了 \(n\) 个点)

\(y_i\) 是乘在外面的常数,可以最后再乘。

\(g(i) = \prod_{j\neq i} x_i-x_j\)\(O(n^2)\) 预处理,这样分母部分也解决了。

对于分子,\(i\neq j\) 有点难处理,我们把它变为 \(\frac{\prod\limits_j (x-x_j)}{x-x_i}\),记多项式 \(F(x) = \prod_j (x-x_j)\),于是我们要做的就是预处理 \(F(x)\) 的系数,然后对于每个 \(i\) 计算 \(\frac{F(x)}{x-x_i}\) 的系数。

对于多项式连乘,由于这些做乘数的每个多项式只有两项,那么可以设 \(f(i,j)\) 表示前 \(i\) 个多项式乘完后,\(j\) 次项的系数,\(f(i,j)=f(i-1,j-1)+f(i-1,j)\times (-x_i)\),意义就是要么把第 \(i\) 项的 \(x\) 乘过去,要么把 \(-x_i\) 乘过去。复杂度 \(O(n^2)\)\(i\) 这一维可以滚动掉。

这里不仅是写在拉插里,只要是求许多多项式连乘的都可以考虑一下这个求 \(f\) 的过程。

记多项式 \(P(x) = \frac{F(x)}{x-x_i}\),设 \(p(i)\) 表示 \(P(x)\)\(i\) 次项的系数,

\(F(x) = f(n)x^n + f(n-1)x^{n-1}+\dots + f(0)x^0\)\(P(x) = p(n-1)x^{n-1} + p(n-2)x^{n-2}+\dots + p(0)x^0\)

由于 \(F(x) = (x-x_i) P(x)\),所以对于每个 \(0\le j<n\),有 \(-x_ip(j+1)x^{j+1} + p(j)x^{j+1}=f(j+1)x^{j+1}\)

得到 \(p(j) = f(j+1) + x_ip(j+1)\)

由于 \(p(n)=0\),所以 \(p(n-1)=f(n)\),于是可以 \(O(n)\) 倒推求出 \(P(x)\) 的系数,

最后只需要枚举 \(1\le i\le n\),每次 \(O(n)\) 求出 \(p\) 之后,将 \(p(j)y_i\frac{1}{g(i)}\) 累加到 \(h(j)\) 中即可。

最终复杂度 \(O(n^2)\)

LGP4781 【模板】拉格朗日插值

拉插求系数的写法。

Code
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;

const int N = 2e3 + 5, MOD = 998244353;

int n, k, x[N], y[N];

int fpow(int a, int b) {
    int ans = 1;
    for (; b; b >>= 1) {
        if (b & 1) ans = (LL)ans * a % MOD;
        a = (LL)a * a % MOD;
    }
    return ans;
}

namespace LAGR { // 拉格朗日插值
int f[N], g[N], h[N];
void lagr() { 
    f[0] = 1;
    for (int i = 1; i <= n; i++) {
        for (int j = i; j >= 0; j--) {
            f[j] = (f[j - 1] + (LL)f[j] * (MOD - x[i])) % MOD;
        }
    }
    for (int i = 1; i <= n; i++) {
        g[i] = 1;
        for (int j = 1; j <= n; j++) {
            if (j == i) continue;
            g[i] = (LL)g[i] * (x[i] - x[j] + MOD) % MOD;
        }
    }
    for (int i = 1; i <= n; i++) {
        int t = (LL)y[i] * fpow(g[i], MOD - 2) % MOD, k = f[n];
        for (int j = n - 1; j >= 0; j--) {
            h[j] = (h[j] + (LL)k * t) % MOD;
            k = (f[j] + (LL)k * x[i]) % MOD;
        }
    }
}    
} using LAGR::lagr, LAGR::h;

int main() {
    cin >> n >> k;
    for (int i = 1; i <= n; i++) {
        cin >> x[i] >> y[i];
    }

    // 普通拉插
    // int ans = 0;
    // for (int i = 1; i <= n; i++) {
    //     int res1 = 1, res2 = 1;
    //     for (int j = 1; j <= n; j++) {
    //         if (j == i) continue;
    //         res1 = (LL)res1 * (k - x[j]) % MOD;
    //         res2 = (LL)res2 * (x[i] - x[j]) % MOD;
    //     }
    //     ans = (ans + (LL)y[i] * res1 % MOD * fpow(res2, MOD - 2) % MOD) % MOD;
    // }
    // cout << (ans + MOD) % MOD << '\n';

    // 求系数的拉插
    int ans = 0, p = 1;
    lagr();
    for (int i : h) {
        ans = (ans + (LL)i * p) % MOD;
        p = (LL)p * k % MOD;
    }
    cout << ans << '\n';

    return 0;
}

P7116 [NOIP2020] 微信步数

Solution

每一维独立,假设当前考虑第 \(j\) 维。

反过来考虑每一步有多少点存活(包括第 \(0\) 步),容易得到暴力做法。

暴力
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;

const int N = 5e5 + 5, MOD = 1e9 + 7;

int n, k, w[N], c[N], d[N], mx[N], mn[N], sum[N];

int main() {
    cin >> n >> k;
    int ans = 1;
    for (int i = 1; i <= k; i++) {
        cin >> w[i];
        ans = (LL)ans * w[i] % MOD;
    }
    for (int i = 1; i <= n; i++) {
        cin >> c[i] >> d[i];
    }
    int cur = 1;
    while (1) {
        if (cur > n) cur = 1;
        sum[c[cur]] += d[cur];
        mx[c[cur]] = max(mx[c[cur]], sum[c[cur]]);
        mn[c[cur]] = min(mn[c[cur]], sum[c[cur]]);
        int res = 1;
        for (int i = 1; i <= k; i++) res = (LL)res * max(0, w[i] - mx[i] + mn[i]) % MOD;
        ans = (ans + res) % MOD;
        if (res == 0) break;
        if (cur == n) {
            bool flag = false;
            for (int i = 1; i <= k; i++) if (sum[i] != 0) flag = true;
            if (!flag) {
                ans = -1;
                break;
            }
        }
        cur++;
    }
    cout << ans << '\n';
    return 0;
}

\(mx_{i,j}\)\(mn_{i,j}\) 表示到第 \(i\) 步的历史最大和最小位移。

那么对于每一步,当前死亡的点(走出边界的点)数为 \(mx_{i,j}-mn_{i,j}\),因为位于 \([1,-mn_{i,j}]\)\([w_j - mx_{i,j}+1, w_j]\) 的点都死了。

\(sum_j\) 表示一轮过后的总位移,就是 \(\sum\limits_{c_i=j} d_i\)

那么对于 \(1\le i\le n\)\(mx_{pn+i} = \max\{mx_{pn}, mx_{(p-1)n + i,j}+sum_{j}\}\)\(p\in Z\)(1),也就是从上一轮第 \(i\) 步到这一轮的第 \(i\) 步,\(mx\) 偏移了 \(\max\{0, sum_j\}\)\(mn\) 同理。

\(dmx_{i,j}\) 表示从第二轮起 \(mx_{i,j}\) 相对于上一轮的最后一步的 \(mx\) 的变化量,即 \(dmx_{i,j} = \max\{0, mx_{i,j}-mx_{i-n,j}\}\),那么 \(dmx_{i,j}-dmn_{i,j}\) 就表示死亡点数的变化量(多死了几个点)。

如果我们代入上面的 \((1)\) 式,把 \(dmx_{i,j}\)\(dmx_{i+n,j}\) 都表示出来,发现 \(dmx_{i,j}=dmx_{i+n,j}\),也就是说变化量都是相等的,于是我们只需要记录第一轮到第二轮的变化量。\(dmn\) 同理。

于是 \(mx\)\(mn\) 只需要记录第一轮的即可。

\(rem_j = w_j-(mx_{n,j} - mn_{n,j})\),也就是第一轮结束后剩余存活人数。

我们先用暴力的方法将第一轮的贡献求出来,并判是否有解,后面从第二轮开始算。

容易表示出第 \(x+2\) 轮第 \(i\) 步的存活人数,\(rem_j - x(dmx_{n,j}-dmn_{n,j})-(dmx_{i,j}-dmn_{i,j})\)

算出 \(x\) 的最大值 \(tot = \min \frac{\max\{0,rem_j - (dmx_{i,j}-dmn_{i,j})\}}{dmx_{n,j}-dmn_{n,j}}\)

于是可以分开每一步考虑,那么从第二轮起的总答案就是 \(\sum\limits_{i=1}^{n} \sum\limits_{x=0}^{tot} \prod\limits_{j=1}^{k} rem_j - x(dmx_{n,j}-dmn_{n,j})-(dmx_{i,j}-dmn_{i,j})\)

\(\prod\) 这里直接用多项式连乘的 Trick \(O(k^2)\) 求出系数,记 \(i\) 次项的系数为 \(f_i\)

然后对于 \(\sum\limits_{x=0}^{tot}\),改变的只是 \(x\) 这一未知量,相当于要求 \(\sum\limits_{t=0}^{k}f(t)\sum\limits_{x=0}^{tot} x^t\),(这里定义 \(0^0=1\)),于是这一部分直接拉插求即可。

总复杂度 \(O(nk^2)\)

Code
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
#define Mod(x) (x >= MOD ? x - MOD : x)

const int N = 5e5 + 5, MOD = 1e9 + 7, INF = 2e9;

int n, k, w[15], c[N], d[N];
int mx[N][15], mn[N][15]; // mx[i][j]: 表示第一轮时第 j 维第 i 步的历史最大位移, mn[i][j] 是最小
int dmx[N][15], dmn[N][15]; // 从第二轮起, 每轮第 j 维第 i 步的历史最大位移 相对于 上一轮结束后的历史最大位移 的 变化量 
int sum[15]; // 一轮的总位移
int rem[15]; // 第一轮结束后剩余存活的点数
int del[N][15]; // 第 j 维第 i 步 相对于 上一轮结束时 新增死亡的点数
int pw[15][15], fact[15], inv[15];
// int x[15], y[15], f[15];

int fpow(int a, int b) {
    int ans = 1;
    for (; b; b >>= 1) {
        if (b & 1) ans = (LL)ans * a % MOD;
        a = (LL)a * a % MOD;
    }
    return ans;
}

namespace LAGR {
int f[20], g[20], pre[20], suf[20];

void lagr(int n, int* x, int* y, int* h) {
    for (int i = 0; i <= n; i++) f[i] = h[i] = 0;
    f[0] = 1;
    for (int i = 1; i <= n; i++) {
        for (int j = i; j >= 0; j--) {
            f[j] = ((LL)f[j] * (MOD - x[i]) + (j > 0 ? f[j - 1] : 0)) % MOD;
        }
    }
    // for (int i = 1; i <= n; i++) {
    //     g[i] = 1;
    //     for (int j = 1; j <= n; j++) {
    //         if (j == i) continue;
    //         g[i] = (LL)g[i] * (x[i] - x[j]) % MOD;
    //     }
    //     g[i] = (g[i] + MOD) % MOD;
    // }
    for (int i = 1; i <= n; i++) {
        // int t = (LL)y[i] * fpow(g[i], MOD - 2) % MOD, k = f[n]; // 普通写法, 但是常数太大了, log(1e9) = 2k
        // 发现 x[i] 连续, 这样优化了一下常数后可以过
        int t = (LL)y[i] * inv[i - 1] % MOD * inv[n - i] % MOD * ((n - i) & 1 ? MOD - 1 : 1) % MOD, k = f[n];
        for (int j = n - 1; j >= 0; j--) {
            h[j] = (h[j] + (LL)k * t) % MOD;
            k = (f[j] + (LL)k * x[i]) % MOD;
        }
    }
}

int calc(int n, int m) { // 计算 \sum_{i = 1}^{n} i^m
    pre[0] = 1;
    for (int i = 1; i <= m + 2; i++) pre[i] = (LL)pre[i - 1] * (n - i) % MOD;
    suf[m + 3] = 1;
    for (int i = m + 2; i >= 1; i--) suf[i] = (LL)suf[i + 1] * (n - i) % MOD;
    int y = 0, ans = 0, op = ((m + 1) % 2) ? -1 : 1;
    for (int i = 1; i <= m + 2; i++) {
        y = Mod(y + pw[i][m]);
        ans = (ans + (LL)op * pre[i - 1] % MOD * suf[i + 1] % MOD * inv[i - 1] % MOD * inv[m + 2 - i] % MOD * y) % MOD;
        op = -op;
    }
    return (ans + MOD) % MOD;
}
} using LAGR::lagr, LAGR::calc;

int main() {
    // freopen("walk3.in", "r", stdin);
    // freopen("P7116_7.in", "r", stdin);
    // freopen("P7116_14.in", "r", stdin);
    // freopen("walk.out", "w", stdout);
    
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);

    cin >> n >> k;
    int ans = 1;
    for (int i = 1; i <= k; i++) {
        cin >> w[i];
        ans = (LL)ans * w[i] % MOD;
    }
    for (int i = 1; i <= n; i++) cin >> c[i] >> d[i];

    // 第一轮
    int res = 1;
    for (int i = 1; i <= n; i++) {
        sum[c[i]] += d[i];
        for (int j = 1; j <= k; j++) mx[i][j] = mx[i - 1][j], mn[i][j] = mn[i - 1][j];
        mx[i][c[i]] = max(mx[i][c[i]], sum[c[i]]);
        mn[i][c[i]] = min(mn[i][c[i]], sum[c[i]]);
        res = 1;
        for (int j = 1; j <= k; j++) res = (LL)res * max(0, w[j] - mx[i][j] + mn[i][j]) % MOD;
        ans = (ans + res) % MOD;
    }
    if (res == 0) {
        cout << ans << '\n';
        return 0;
    } else { // 如果回到原点了 并且 还有人活着 则 无解
        bool flag = false;
        for (int i = 1; i <= k; i++) if (sum[i] != 0) flag = true;
        if (!flag) {
            cout << -1 << '\n';
            return 0;
        }
    }

    // cout << "First: " << ans << '\n';

    for (int i = 1; i <= k; i++) {
        rem[i] = max(0, w[i] - mx[n][i] + mn[n][i]);
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= k; j++) {
            dmx[i][j] = max(0, mx[i][j] + sum[j] - mx[n][j]);
            dmn[i][j] = min(0, mn[i][j] + sum[j] - mn[n][j]);
            del[i][j] = dmx[i][j] - dmn[i][j];
        }
    }

    fact[0] = 1;
    for (int i = 1; i <= k + 2; i++) {
        for (int j = 0; j <= k; j++) pw[i][j] = fpow(i, j);
        fact[i] = (LL)fact[i - 1] * i % MOD;
    }
    inv[k + 2] = fpow(fact[k + 2], MOD - 2);
    for (int i = k + 1; i >= 0; i--) inv[i] = (LL)inv[i + 1] * (i + 1) % MOD;

    // 暴力
    // for (int i = 1; i <= n; i++) {
    //     int tot = INF;
    //     for (int j = 1; j <= k; j++) {
    //         if (del[n][j] == 0) continue;
    //         tot = min(tot, max(0, rem[j] - del[i][j]) / del[n][j]);
    //     }
    //     for (int x = 0; x <= tot; x++) {
    //         res = 1;
    //         for (int j = 1; j <= k; j++) {
    //             res = (LL)res * (max(0, rem[j] - del[i][j]) - x * del[n][j]) % MOD;
    //         }
    //         if (res == 0) break;
    //         ans = (ans + res) % MOD;
    //     }
    // }

    int lst = -1;
    vector<int> rec(k + 1), f(k + 1, 0);
    for (int i = 1; i <= n; i++) {
        // cout << "i: " << i << '\n';
        int tot = INF;
        for (int j = 1; j <= k; j++) {
            if (del[n][j] == 0) continue;
            tot = min(tot, max(0, rem[j] - del[i][j]) / del[n][j]);
        }

        if (tot <= k) {
            for (int j = 0; j <= tot; j++) {
                res = 1;
                for (int t = 1; t <= k; t++) {
                    res = (LL)res * (max(0, rem[t] - del[i][t]) - (LL)j * del[n][t]) % MOD;
                }
                if (res == 0) break;
                ans = (ans + res) % MOD;
            }
            // cout << "ans1: " << ans << '\n';
            continue;
        }

        // 大常数写法
        // for (int j = 1; j <= k + 1; j++) {
        //     x[j] = j;
        //     y[j] = 1;
        //     for (int t = 1; t <= k; t++) {
        //         y[j] = (LL)y[j] * (rem[t] - del[i][t] - (LL)j * del[n][t]) % MOD;
        //     }
        // }
        // lagr(k + 1, x, y, f);

        // 但是这个多项式是好几个 一次二项多项式 的乘积, 那么可以直接 DP, 也就相当于直接用 lagr() 前面求 f 的部分即可
        // 这么写常数很小, 比之前快了 1s
        for (int j = 0; j <= k; j++) f[j] = 0;
        f[0] = 1;
        for (int j = 1; j <= k; j++) {
            int x = rem[j] - del[i][j];
            for (int t = j; t >= 0; t--) {
                f[t] = (LL)f[t] * x % MOD;
                if (t > 0) f[t] = (f[t] + (LL)f[t - 1] * (-del[n][j])) % MOD;
            }
        }
        for (int j = 0; j <= k; j++) f[j] = (f[j] + MOD) % MOD;

        ans = (ans + (LL)f[0] * (tot + 1) % MOD) % MOD;
        if (tot != lst) { // 剪枝, 不重复计算, 也快了将近 1s
            lst = tot;
            for (int j = 1; j <= k; j++) rec[j] = calc(tot, j);
        }
        for (int j = 1; j <= k; j++) {
            ans = (ans + (LL)f[j] * rec[j]) % MOD;
        }
        // cout << "ans2: " << ans << '\n';
    }

    cout << ans << '\n';

    cerr << "Time: " << 1000 * clock() / CLOCKS_PER_SEC << '\n';
    return 0;
}

拉插优化DP

CF995F Cowmpany Cowmpensation

Solution

引理:若 \(F(x)\)\(n\) 次多项式,则 \(G(x)=\sum\limits_{y=1}^{x} F(y)\)\(n+1\) 次多项式。

证明:

\(F(x) = f_nx^n + \dots + f_0\),则 \(G(x) = f_n\sum\limits_{y=1}^{x} y^n + \dots + f_0x\)

而前面已经证过,\(\sum\limits_{y=1}^{x} y^n\) 是一个 \(n+1\) 次多项式,于是 \(G(x)\) 是一个 \(n+1\) 次多项式。

\(f(u,i)\) 表示 \(u\) 选择 \(i\),以 \(u\) 为根的子树的方案数。

转移方程有,\(f(u,i)=\prod\limits_{v} \sum\limits_{j=1}^{i} f(v,i)\)

因为当 \(u\) 为叶子时,\(f(u,i)=1\),是一个关于 \(i\)\(0\) 次多项式,

我们归纳假设 \(u\) 的儿子 \(v\) 满足 \(f(v,i)\) 是一个关于 \(i\)\(size_v-1\) 次多项式,

于是根据上面的引理以及转移方程,\(f(u,i)\) 是一个关于 \(i\)\(size_u-1\) 次多项式。

\(g(u,i)=\sum\limits_{j=1}^{i}f(u,i)\),显然 \(g(u,i)\) 是关于 \(i\)\(size_u\) 次多项式。

答案要我们求 \(g(1,d)\),因此答案是一个 \(n\) 次多项式,那么我们只需要 \(O(n^2)\) 求出 \(g(1,[1,n+1])\),直接用 给定横坐标连续 的方法插值即可。当然朴素的插值也是可以过的。

Code
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;

const int N = 3005, MOD = 1e9 + 7;

int n, m;
vector<int> G[N];
int a[N], dp[N][N], sum[N][N];
int fact[N], inv[N];
int pre[N], suf[N];

int fpow(int a, int b) {
    int ans = 1;
    for (; b; b >>= 1) {
        if (b & 1) ans = (LL)ans * a % MOD;
        a = (LL)a * a % MOD;
    }
    return ans;
}

void init() {
    fact[0] = 1;
    for (int i = 1; i <= n + 1; i++) fact[i] = (LL)fact[i - 1] * i % MOD;
    inv[n + 1] = fpow(fact[n + 1], MOD - 2);
    for (int i = n; i >= 0; i--) inv[i] = (LL)inv[i + 1] * (i + 1) % MOD;
}

void dfs(int u) {
    for (int v : G[u]) dfs(v);
    for (int i = 1; i <= n + 1; i++) {
        dp[u][i] = 1;
        for (int v : G[u]) {
            dp[u][i] = (LL)dp[u][i] * sum[v][i] % MOD;
        }
        sum[u][i] = (sum[u][i - 1] + dp[u][i]) % MOD;
    }
}

int main() {
    cin >> n >> m;
    init();
    int x;
    for (int i = 2; i <= n; i++) {
        cin >> x;
        G[x].push_back(i);
    }
    dfs(1);
    pre[0] = 1;
    for (int i = 1; i <= n + 1; i++) pre[i] = (LL)pre[i - 1] * (m - i) % MOD;
    suf[n + 2] = 1;
    for (int i = n + 1; i >= 1; i--) suf[i] = (LL)suf[i + 1] * (m - i) % MOD;
    int ans = 0, op = (n % 2 ? -1 : 1);
    for (int i = 1; i <= n + 1; i++) {
        ans = (ans + (LL)op * sum[1][i] * inv[i - 1] % MOD * inv[n + 1 - i] % MOD * pre[i - 1] % MOD * suf[i + 1]) % MOD;
        op = -op;
    }
    cout << (ans + MOD) % MOD << '\n';
    return 0;
}

// dp[1][x] 是关于 x 的 n-1 次多项式
// sum[1][x] 是关于 x 的 n 次多项式

LGP8290 [省选联考 2022] 填树

Solution

怎么一开始就想错了,设的是 \(dp(u,l,r)\),看着就没优化的前途。T_T

枚举最小值 \(x\),把 \([l_u,r_u]\) 变成 \([l_u,r_u]\cap [x,x+k]\),再在树上进行 DP。

但是这样求的方案的最小值可能 \(>x\),考虑容斥,对于每个 \(x\),再 DP 求一遍 \([x+1,x+k]\),用 \([x,x+k]\) 减去即可。

\(f_1(u)\) 表示以 \(u\) 为端点,向 \(u\) 子树内走的合法链的方案数,\(f_2(u)\) 表示 \(u\) 的子树内合法链的方案数。

\(g_1(u)\) 表示以 \(u\) 为端点,向 \(u\) 子树内走的合法链的权值和,\(g_2(u)\) 表示 \(u\) 的子树内合法链的权值和。

假设当前转移到 \(u\),设 \(u\) 能取到的范围为 \([L_u,R_u] = [l_u,r_u]\cap [x,x+k]\)\(len=R_u-L_u+1\)\(sum=\frac{(L_u+R_u)\cdot len}{2}\)

\(f_1(u) = len + \sum\limits_{v\in son_u} f_1(v)\cdot len\)

\(g_1(u) = sum + \sum\limits_{v\in son_u} g_1(v)\cdot len + sum\cdot f_1(v)\)

\(f_2(u) = len + \sum\limits_{v\in son_u} f_1(v)\cdot (1+\sum\limits_{pre}f_1(pre))\cdot len + \sum\limits_{v\in son_u} f_2(v)\)

\(g_2(u) = sum + \sum\limits_{v\in son_u} (g_1(v)\cdot len + sum\cdot f_1(v))\cdot (1+\sum\limits_{pre}f_1(pre)) + (\sum\limits_{pre} g_1(pre)) \cdot f_1(v) \cdot len + \sum\limits_{v\in son_u} g_2(v)\)

于是可以得到一个 \(O(nV)\) 的暴力。\(V\) 是值域。可以获得 40 分。

Code
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;

const int N = 205, MOD = 1e9 + 7, INF = 2e9;

int inv2;
int n, k;
int l[N], r[N], L[N], R[N];
vector<int> G[N];
int f1[N], f2[N], g1[N], g2[N];

int fpow(int a, int b) {
    int ans = 1;
    for (; b; b >>= 1) {
        if (b & 1) ans = (LL)ans * a % MOD;
        a = (LL)a * a % MOD;
    }
    return ans;
}

void dfs(int u, int pre) {
    int len = max(0, R[u] - L[u] + 1), sum = (LL)(R[u] + L[u]) * len % MOD * inv2 % MOD;
    int lstf = 0, lstg = 0;
    f1[u] = f2[u] = len;
    g1[u] = g2[u] = sum;
    for (int v : G[u]) {
        if (v == pre) continue;
        dfs(v, u);
        f1[u] = (f1[u] + (LL)f1[v] * len) % MOD;
        g1[u] = (g1[u] + (LL)g1[v] * len + (LL)sum * f1[v]) % MOD;
        f2[u] = (f2[u] + (LL)f1[v] * (1 + lstf) % MOD * len + f2[v]) % MOD;
        g2[u] = (g2[u] + ((LL)g1[v] * len + (LL)sum * f1[v]) % MOD * (1 + lstf) + (LL)lstg * f1[v] % MOD * len + g2[v]) % MOD;
        lstf = (lstf + f1[v]) % MOD;
        lstg = (lstg + g1[v]) % MOD;
    }
}

PII work(int x, int y) {
    if (x > y) return {0, 0};
    for (int i = 1; i <= n; i++) {
        L[i] = max(l[i], x);
        R[i] = min(r[i], y);
    }
    dfs(1, 0);
    return {f2[1], g2[1]};
}

int main() {
    inv2 = fpow(2, MOD - 2);
    cin >> n >> k;
    int mxr = 0, mnl = INF;
    for (int i = 1; i <= n; i++) {
        cin >> l[i] >> r[i];
        mxr = max(mxr, r[i]);
        mnl = min(mnl, l[i]);
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    int ans1 = 0, ans2 = 0;
    for (int x = mnl; x <= mxr; x++) {
        auto i = work(x, x + k), j = work(x + 1, x + k);
        ans1 = (ans1 + (i.first - j.first + MOD) % MOD) % MOD;
        ans2 = (ans2 + (i.second - j.second + MOD) % MOD) % MOD;
    }
    cout << ans1 << '\n' << ans2 << '\n';
    return 0;
}

发现 \([x,x+k]\) 在值域上滑动时,\(len\) 可以表示为关于 \(x\) 的一次多项式,\(sum\) 是关于 \(x\) 的二次多项式。

\(x\)\(x+k\) 从某个 \(l_i-1\) 滑到 \(l_i\),或从某个 \(r_i\) 滑到 \(r_i+1\) 的时候,\(len\)\(sum\) 的多项式才会发生改变,这样改变的次数只有 \(O(n)\) 次,

而其余大部分时候这些多项式都是不变的,又因为树的形态不变,转移方程不变,那么答案的多项式也不变,只有自变量 \(x\) 在改变,

于是我们只需要求出答案多项式的次数,然后再每个不会改变多项式的段里,暴力算出前若干项的 \(f_2()\)\(g_2()\) 的前缀和,那么只需要拉插求这一段末尾的前缀和即可。

\(d_u\) 表示 \(u\) 到最远的叶子节点的路径上的点数,

根据 \(len\)\(sum\) 的次数,结合转移方程,容易得到 \(\deg f_1(u) = d_u\)\(\deg g_1(u) = d_u+1\),那么容易用一遍 \(dfs\) 求出 \(\deg f_2(1)\),根据转移方程也可以推算出,\(\deg g_2(1)=\deg f_2(1)+1\)

Code
// deg 即为 f1(1) 的次数
int deg = 0;
int get_dep(int u, int pre) {
    int mx1 = 0, mx2 = 0, x;
    for (int v : G[u]) {
        if (v == pre) continue;
        x = get_dep(v, u);
        if (x > mx1) {
            mx2 = mx1;
            mx1 = x;
        } else if (x > mx2) mx2 = x;
    }
    deg = max(deg, mx1 + mx2 + 1);
    return mx1 + 1;
}

算出 \(g_2(1)\) 的次数了之后,由于要求的是前缀和,所以次数还要加一,

因此每一段取前 \(deg + 3\) 个点暴力计算前缀和,然后用给定横坐标连续的拉插求解,

这里为了方便,\(\sum f_2(1)\)\(\sum g_2(1)\) 都统一插 \(deg+3\) 个点即可,因为插多了点不会影响答案。

总复杂度 \(O(n^3)\)

按照我个人理解,\([x,x+1]\) 的答案和 \([x+1,x+k]\) 的答案应该分开计算才对,因为某一段里面 \([x,x+k]\) 的滑动不影响多项式,但是 \([x+1,x+k]\) 可能会在最后跨出去一格,按理来说应该是会影响的。但是直接合在一起计算也是能过的。这里我按照自己的理解来写了。

Code
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;

const int N = 205, MOD = 1e9 + 7, INF = 2e9;

int inv2;
int n, k;
int l[N], r[N], L[N], R[N];
vector<int> G[N];
int f1[N], f2[N], g1[N], g2[N];
int sum1[N], sum2[N];
int pre[N], suf[N];
int fact[N], inv[N];
int deg;

int fpow(int a, int b) {
    int ans = 1;
    for (; b; b >>= 1) {
        if (b & 1) ans = (LL)ans * a % MOD;
        a = (LL)a * a % MOD;
    }
    return ans;
}

void init() {
    inv2 = fpow(2, MOD - 2);
    int n = deg + 1;
    fact[0] = 1;
    for (int i = 1; i <= n; i++) fact[i] = (LL)fact[i - 1] * i % MOD;
    inv[n] = fpow(fact[n], MOD - 2);
    for (int i = n - 1; i >= 0; i--) inv[i] = (LL)inv[i + 1] * (i + 1) % MOD;
}

void dfs(int u, int pre) {
    int len = max(0, R[u] - L[u] + 1), sum = (LL)(R[u] + L[u]) * len % MOD * inv2 % MOD;
    int lstf = 0, lstg = 0;
    f1[u] = f2[u] = len;
    g1[u] = g2[u] = sum;
    for (int v : G[u]) {
        if (v == pre) continue;
        dfs(v, u);
        f1[u] = (f1[u] + (LL)f1[v] * len) % MOD;
        g1[u] = (g1[u] + (LL)g1[v] * len + (LL)sum * f1[v]) % MOD;
        f2[u] = (f2[u] + (LL)f1[v] * (1 + lstf) % MOD * len + f2[v]) % MOD;
        g2[u] = (g2[u] + ((LL)g1[v] * len + (LL)sum * f1[v]) % MOD * (1 + lstf) + (LL)lstg * f1[v] % MOD * len + g2[v]) % MOD;
        lstf = (lstf + f1[v]) % MOD;
        lstg = (lstg + g1[v]) % MOD;
    }
}

PII work(int x, int y) {
    if (x > y) return {0, 0};
    for (int i = 1; i <= n; i++) {
        L[i] = max(l[i], x);
        R[i] = min(r[i], y);
    }
    dfs(1, 0);
    return {f2[1], g2[1]};
}

PII lagr(int x) {
    int tot = deg + 1;
    pre[0] = 1;
    for (int i = 1; i <= tot; i++) pre[i] = (LL)pre[i - 1] * (x - i) % MOD;
    suf[tot + 1] = 1;
    for (int i = tot; i >= 1; i--) suf[i] = (LL)suf[i + 1] * (x - i) % MOD;
    int ans1 = 0, ans2 = 0, op = ((tot - 1) & 1) ? -1 : 1;
    for (int i = 1; i <= tot; i++) {
        ans1 = (ans1 + (LL)op * sum1[i] * inv[i - 1] % MOD * inv[tot - i] % MOD * pre[i - 1] % MOD * suf[i + 1]) % MOD;
        ans2 = (ans2 + (LL)op * sum2[i] * inv[i - 1] % MOD * inv[tot - i] % MOD * pre[i - 1] % MOD * suf[i + 1]) % MOD;
        op = -op;
    }
    ans1 = (ans1 + MOD) % MOD;
    ans2 = (ans2 + MOD) % MOD;
    return {ans1, ans2};
}

int get_dep(int u, int pre) {
    int mx1 = 0, mx2 = 0, x;
    for (int v : G[u]) {
        if (v == pre) continue;
        x = get_dep(v, u);
        if (x > mx1) {
            mx2 = mx1;
            mx1 = x;
        } else if (x > mx2) mx2 = x;
    }
    deg = max(deg, mx1 + mx2 + 1);
    return mx1 + 1;
}

int main() {
    // freopen("P8290_3.in", "r", stdin);
    // freopen("tree.out", "w", stdout);

    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);

    cin >> n >> k;
    
    int mnl = INF;
    for (int i = 1; i <= n; i++) {
        cin >> l[i] >> r[i];
        mnl = min(mnl, l[i]);
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        G[u].push_back(v);
        G[v].push_back(u);
    }

    get_dep(1, 0);
    deg += 2;

    init();

    vector<int> pos;
    for (int i = 1; i <= n; i++) {
        pos.push_back(l[i]);
        if (l[i] - k >= mnl) pos.push_back(l[i] - k);
        pos.push_back(r[i] + 1);
        if (r[i] + 1 - k >= mnl) pos.push_back(r[i] + 1 - k);
    }
    sort(pos.begin(), pos.end());
    int m = unique(pos.begin(), pos.end()) - pos.begin();
    int ans1 = 0, ans2 = 0;
    for (int i = 0; i + 1 < m; i++) {
        int len = pos[i + 1] - pos[i];
        if (len <= deg + 1) {
            for (int x = pos[i]; x < pos[i + 1]; x++) {
                auto a = work(x, x + k);
                ans1 = (ans1 + a.first) % MOD;
                ans2 = (ans2 + a.second) % MOD;
            }
        } else {
            for (int x = pos[i]; x <= pos[i] + deg; x++) {
                auto a = work(x, x + k);
                sum1[x - pos[i] + 1] = (sum1[x - pos[i]] + a.first) % MOD;
                sum2[x - pos[i] + 1] = (sum2[x - pos[i]] + a.second) % MOD;
            }
            auto a = lagr(pos[i + 1] - pos[i]);
            ans1 = (ans1 + a.first) % MOD;
            ans2 = (ans2 + a.second) % MOD;
        }
    }

    pos.clear();
    pos.push_back(mnl + 1);
    for (int i = 1; i <= n; i++) {
        if (l[i] >= mnl + 1) pos.push_back(l[i]);
        if (l[i] - k + 1 >= mnl + 1) pos.push_back(l[i] - k + 1);
        pos.push_back(r[i] + 1);
        if (r[i] - k + 2 >= mnl + 1) pos.push_back(r[i] - k + 2);
    }
    sort(pos.begin(), pos.end());
    m = unique(pos.begin(), pos.end()) - pos.begin();
    for (int i = 0; i + 1 < m; i++) {
        int len = pos[i + 1] - pos[i];
        if (len <= deg + 1) {
            for (int x = pos[i]; x < pos[i + 1]; x++) {
                auto a = work(x, x + k - 1);
                ans1 = (ans1 - a.first + MOD) % MOD;
                ans2 = (ans2 - a.second + MOD) % MOD;
            }
        } else {
            for (int x = pos[i]; x <= pos[i] + deg; x++) {
                auto a = work(x, x + k - 1);
                sum1[x - pos[i] + 1] = (sum1[x - pos[i]] + a.first) % MOD;
                sum2[x - pos[i] + 1] = (sum2[x - pos[i]] + a.second) % MOD;
            }
            auto a = lagr(pos[i + 1] - pos[i]);
            ans1 = (ans1 - a.first + MOD) % MOD;
            ans2 = (ans2 - a.second + MOD) % MOD;
        }
    }

    cout << ans1 << '\n' << ans2 << '\n';
    return 0;
}

LGP5469 [NOI2019] 机器人

Solution

这道题有个很好的性质。对于一个序列来说,以最右边的最大值为分界点,这个位置可以走到整个序列的头和尾,并且左右两边的区间不会互相走到,那么左右两边可以分别归纳为两个子问题。于是我们根据这个来设计 DP。

\(f(l,r,x)\) 表示区间 \([l,r]\) 最大值不超过 \(x\) 的方案数,直接枚举最右边的最大值转移,并且保证这个最大值分别与 \(l\)\(r\) 距离的差的绝对值不超过 \(2\)。前缀和优化后可以得到 \(O(n^2V)\) 的做法(枚举转移的点很少,可以算为常数)。

由于空间复杂度无法接受,因此只能获得 35 分。

#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;

const int N = 305, MOD = 1e9 + 7, INF = 2e9;

int n, a[N], b[N];
int dp[N][N][N];

int main() {
    cin >> n;
    int mxb = 0, mna = INF;
    for (int i = 1; i <= n; i++) {
        cin >> a[i] >> b[i];
        mxb = max(mxb, b[i]);
        mna = min(mna, a[i]);
    }
    for (int i = 1; i <= n; i++) {
        for (int j = a[i]; j <= b[i]; j++) dp[i][i][j] = (dp[i][i][j - 1] + 1) % MOD;
        for (int j = b[i] + 1; j <= mxb; j++) dp[i][i][j] = dp[i][i][j - 1];
    }
    for (int i = 0; i <= n; i++) for (int j = 0; j <= mxb; j++) dp[i + 1][i][j] = 1;
    for (int len = 2; len <= n; len++) {
        for (int l = 1; l + len - 1 <= n; l++) {
            int r = l + len - 1;
            int mid = l + r >> 1;
            for (int x = mna; x <= mxb; x++) {
                int L = max(l, mid - (len & 1)), R = min(r, mid + 1);
                for (int k = L; k <= R; k++) {
                    if (a[k] <= x && x <= b[k]) {
                        dp[l][r][x] = (dp[l][r][x] + (LL)dp[l][k - 1][x] * dp[k + 1][r][x - 1]) % MOD;
                    }
                }
                dp[l][r][x] = (dp[l][r][x] + dp[l][r][x - 1]) % MOD;
            }
        }
    }
    cout << dp[1][n][mxb] << '\n';
    return 0;
}

由于每次枚举的转移点很少,于是通过以下代码打表发现,当 \(n\le 300\) 时,用到的区间个数 \(m\le 2220\)(打表真的很重要),

// 打表代码
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;

map<PII, bool> mp;
int n;

void calc(int l, int r) {
    if (l > r) return;
    mp[{l, r}] = true;
    int mid = l + r >> 1;
    int len = r - l + 1;
    int L = max(l, mid - (len & 1)), R = min(r, mid + 1);
    for (int i = L; i <= R; i++) {
        calc(l, i - 1);
        calc(i + 1, r);
    }
}

int main() {
    int mx = 0;
    cin >> n;
    for (int i = 1; i <= n; i++) {
        mp.clear();
        calc(1, i);
        mx = max(mx, (int)mp.size());
    }
    cout << mx << '\n';
    return 0;
}

那么我们预处理所有区间,按照长度升序排序,然后编号,设 \(f(i,x)\)\(i\) 个区间最大值不超过 \(x\) 的方案数,即可得到 \(O(mV)\) 的做法,获得 50 分。

#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;

const int N = 305, MOD = 1e9 + 7, INF = 2e9;

int n, a[N], b[N];
int id[N][N];
int dp[2600][10005];
vector<PII> rng;
map<PII, bool> vis;

void calc(int l, int r) {
    if (l > r) return;
    if (!vis.count({l, r})) {
        rng.push_back({l, r});
        vis[{l, r}] = true;
    }
    int mid = l + r >> 1;
    int len = r - l + 1;
    int L = max(l, mid - (len & 1)), R = min(r, mid + 1);
    for (int i = L; i <= R; i++) {
        calc(l, i - 1);
        calc(i + 1, r);
    }
}

int main() {
    cin >> n;
    int mxb = 0, mna = INF;
    for (int i = 1; i <= n; i++) {
        cin >> a[i] >> b[i];
        mxb = max(mxb, b[i]);
        mna = min(mna, a[i]);
    }
    calc(1, n);
    sort(rng.begin(), rng.end(), [&](auto i, auto j) { return i.second - i.first < j.second - j.first; });
    for (int i = 0; i <= mxb; i++) dp[0][i] = 1;
    for (int i = 0; i < rng.size(); i++) {
        int l = rng[i].first, r = rng[i].second;
        id[l][r] = i + 1;
        for (int j = mna; j <= mxb; j++) {
            int len = r - l + 1, mid = l + r >> 1;
            int L = max(l, mid - (len & 1)), R = min(r, mid + 1);
            for (int k = L; k <= R; k++) {
                if (a[k] <= j && j <= b[k]) {
                    dp[i + 1][j] = (dp[i + 1][j] + (LL)dp[id[l][k - 1]][j] * dp[id[k + 1][r]][j - 1]) % MOD;
                }
            }
            dp[i + 1][j] = (dp[i + 1][j] + dp[i + 1][j - 1]) % MOD;
        }
    }
    cout << dp[rng.size()][mxb] << '\n';
    return 0;
}

由转移方程我们可以知道 \(f(i,x)\)\((l_i=r_i)\),是一个关于 \(x\) 的不超过一次多项式,但是它只有在 \([a_i,b_i]\) 的区间内才是连续的,换句话说,只有在这个区间内,多项式的系数才是不变的。

我们可以归纳地得到 \(f(i,x)\) 是一个关于 \(x\) 的不超过 \(r_i-l_i+1\) 次多项式,

由于两段连续的多项式相乘还是连续的,于是我们可以让 \(a_i\)\(b_i\) 把值域划分成 \(O(n)\) 段,每一段的 \(f\) 都是连续的多项式,然后就可以在值域上从小到大遍历每一段,假设是 \([c_j,d_j]\),先求出前 \(n+1\) 项的 \(f(i,x)\)\(1\le i\le m\),然后拉插求出 \(f(i,d_j)\),再用这个值求出 \([c_{j+1},d_{j+1}]\) 段的答案,一直递推下去,就可以求出 \(f(i,mxb)\) 的答案。

由于拉插时给定的点的横坐标连续,于是可以 \(O(n)\) 求解,总复杂度 \(O(n^2m)\)。本题卡常。

#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using PII = pair<int, int>;

const int N = 305, MOD = 1e9 + 7, INF = 2e9;

int n, a[N], b[N], m, cnt, id[N][N], dp[2600][N], fact[N], inv[N], pre[N], suf[N];
vector<int> pos;
vector<PII> rng;
map<PII, bool> vis;

inline int fpow(int a, int b) {
    int ans = 1;
    for (; b; b >>= 1) {
        if (b & 1) ans = (LL)ans * a % MOD;
        a = (LL)a * a % MOD;
    }
    return ans;
}

void init() {
    int n = N - 1;
    fact[0] = 1;
    for (int i = 1; i <= n; i++) fact[i] = (LL)fact[i - 1] * i % MOD;
    inv[n] = fpow(fact[n], MOD - 2);
    for (int i = n - 1; i >= 0; i--) inv[i] = (LL)inv[i + 1] * (i + 1) % MOD;
}

void calc(int l, int r) { // 预处理所有需要用到的区间
    if (l > r) return;
    if (!vis.count({l, r})) {
        rng.push_back({l, r});
        vis[{l, r}] = true;
    }
    int mid = l + r >> 1;
    int len = r - l + 1;
    int L = max(l, mid - (len & 1)), R = min(r, mid + 1);
    for (int i = L; i <= R; i++) {
        calc(l, i - 1);
        calc(i + 1, r);
    }
}

inline void lagr(int n, int m) { // 插入 n 个点, 求多项式 F(m)
    if (n >= m) {
        for (int i = 1; i <= cnt; i++) dp[i][0] = dp[i][m];
        return;
    }
    pre[0] = 1;
    for (int i = 1; i <= n; i++) pre[i] = (LL)pre[i - 1] * (m - i) % MOD;
    suf[n + 1] = 1;
    for (int i = n; i >= 1; i--) suf[i] = (LL)suf[i + 1] * (m - i) % MOD;
    for (int i = 1; i <= cnt; i++) dp[i][0] = 0;
    int op = ((n - 1) & 1) ? MOD - 1 : 1;
    for (int i = 1; i <= n; i++) {
        int res = (LL)op * inv[i - 1] % MOD * inv[n - i] % MOD * pre[i - 1] % MOD * suf[i + 1] % MOD;
        op = MOD - op;
        for (int j = 1; j <= cnt; j++) dp[j][0] = (dp[j][0] + (LL)res * dp[j][i]) % MOD;
    }
}

int main() {
    // freopen("P5469_20.in", "r", stdin);

    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> a[i] >> b[i];
        pos.push_back(a[i] - 1);
        pos.push_back(b[i]);
    }
    sort(pos.begin(), pos.end());
    auto it = unique(pos.begin(), pos.end());
    pos.erase(it, pos.end());
    m = pos.size();
    for (int i = 1; i <= n; i++) {
        a[i] = lower_bound(pos.begin(), pos.end(), a[i] - 1) - pos.begin();
        b[i] = lower_bound(pos.begin(), pos.end(), b[i]) - pos.begin() - 1;
    }

    init();

    calc(1, n);
    sort(rng.begin(), rng.end(), [&](auto i, auto j) { return i.second - i.first < j.second - j.first; });
    cnt = rng.size();

    for (int i = 0; i <= n + 1; i++) dp[0][i] = 1;

    for (int i = 0; i + 1 < m; i++) {
        int tot = min(n + 1, pos[i + 1] - pos[i]);
        for (int j = 1; j <= cnt; j++) {
            int l = rng[j - 1].first, r = rng[j - 1].second;
            id[l][r] = j;
            // for (int k = 1; k <= tot; k++) {
            //     for (int p = l; p <= r; p++) {
            //         if (abs((p - l) - (r - p)) <= 2 && a[p] <= i && i <= b[p]) {
            //             dp[j][k] = (dp[j][k] + (LL)dp[id[l][p - 1]][k] * dp[id[p + 1][r]][k - 1]) % MOD;
            //         }
            //     }
            //     dp[j][k] = (dp[j][k] + dp[j][k - 1]) % MOD;
            // }

            // 常数更优的实现方法
            for (int k = l; k <= r; k++) {
                if (abs((k - l) - (r - k)) <= 2 && a[k] <= i && i <= b[k]) {
                    for (int p = 1; p <= tot; p++) {
                        dp[j][p] = (dp[j][p] + (LL)dp[id[l][k - 1]][p] * dp[id[k + 1][r]][p - 1]) % MOD;
                    }
                }
            }
            for (int k = 1; k <= tot; k++) dp[j][k] = (dp[j][k] + dp[j][k - 1]) % MOD;
        }
        lagr(n + 1, pos[i + 1] - pos[i]);
        for (int j = 1; j <= cnt; j++) {
            for (int k = 1; k <= tot; k++) dp[j][k] = 0;
        }
    }
    cout << dp[cnt][0] << '\n';

    cerr << "Time: " << 1000 * clock() / CLOCKS_PER_SEC << "ms\n";
    return 0;
}
posted @ 2025-03-17 10:22  chenwenmo  阅读(94)  评论(1)    收藏  举报