CF 2127F Hamed and AghaBalaSar

怎么拆贡献又把自己拆乱了,我也是完蛋了。

首先考虑这个 \(f(a)\) 是什么。
发现跳(仅限第一个操作)的时候形如 \(x\to nxt(x)\to nxt(nxt(x))\to \cdots\),贡献就为 \((a_{nxt(x)} - a_x) + (a_{nxt(nxt(x))} - a_{nxt(x)})\)
于是在抵消之后,如果 \(x\) 最终跳到了 \(y\),那么贡献就是 \(a_y - a_x\)

然后会发现在给定了 \(a_n = \max\{a_1, a_2, \cdots, a_n\}\) 的情况下,最后跳到的 \(y\) 一定是离 \(x\) 最近的 \(a_y = a_n\)\(y\),然后若 \(y + 1 < n\) 就会从 \(y + 1\) 继续跳下去。

于是可以写出 \(f(a)\) 的表达式:\(a_n\sum\limits_{i = 1}^n[a_i = a_n] - \sum\limits_{i = 2}^n a_i[a_{i -1} = a_n] - a_1\)

接下来考虑计数。

观察到 \(f(a)\) 的表达式与 \(a_n\),也就是最大值有关,并且整个序列都受最大值的限制。
于是先尝试在外层枚举 \(a_n = mx\)

此时尝试分开计算两部分的贡献:

  1. \(a_n\sum\limits_{i = 1}^n [a_i = a_n]\)
  2. \(\sum\limits_{i = 2}^n a_i[a_{i - 1} = a_n] + a_1\)

首先考虑第 \(1\) 部分,\(a_n\) 因为是枚举的可以抛开不看。
那么这个时候就可以把 \(\sum\) 拆开转为对每个 \(i\) 计算,会发现有 \(2\) 种情况:

  1. \(i = n\),这是因为 \(a_n\) 的值是固定的,而其他的值不是。
  2. \(1\le i < n\)

对于第 \(1\) 种情况,发现此时就需要解决一个问题:
计算 \(f(n, m, mx)\) 代表长度为 \(n\),值域为 \([0, mx]\),和为 \(m\) 的序列数量。

对此并没有什么好求的做法,于是只能考虑容斥 \(> mx\) 的数的数量,可以得到:
\(f(n, m, mx) = \displaystyle\sum\limits_{i = 0}^n (-1)^i\dbinom{n}{i}\dbinom{m - i(mx + 1) + n - 1}{n - 1}\)

看似好像求解就需要 \(\mathcal{O}(n)\) 的复杂度,但是实际上只有 \(i(mx + 1)\le m\)\(i\) 是有用的,所以复杂度是 \(\mathcal{O}(\min\{n, \frac{m}{mx}\})\)
那么正好 \(mx\) 是从 \(1\)\(n\) 的,总复杂度就是 \(\mathcal{O}(m\ln m)\),看样子就很有道理。

\(1\) 种情况的方案数就为 \(f(n - 1, m - mx, mx)\)

然后来考虑第 \(2\) 种情况。
因为每个元素只关心是不是 \(= mx\),所以能够发现 \(1\sim n - 1\)\(n - 1\) 个位置的情况是同样的,那么就只需要求出一个下标的值再乘上 \((n - 1)\) 就可以了。
单个下标是好算的,相当于此时的限制是 \(a_i = a_n = mx\),那么方案数就为 \(f(n - 2, m - 2mx, mx)\)

于是第 \(2\) 种情况的方案数就为 \(f(n - 2, m - 2mx, mx)\)

接下来来考虑第 \(2\) 部分。

一样的尝试分讨,分为以下 \(3\) 种情况:

  • \(i = 1\)
  • \(1 < i < n\)
  • \(i = n\)

对于第 \(1\) 种情况,如果要枚举 \(a_1\) 的值复杂度就爆掉了,于是尝试其他的做法。
考虑 \(f(n - 1, m - mx, mx)\) 种合法的 \(a_1\sim a_{n - 1}\) 的方案,这 \(n - 1\) 个元素的和都为 \(m - mx\)
又因为这 \(n - 1\) 个元素是平等的,所以期望下应当有 \(a_1 = \frac{m - mx}{n - 1}\)
或者说,也可以考虑一个合法的 \(a_1\sim a_{n - 1}\) 的方案,考虑把其所有排列(根据值相同定义排列相同),那么每个位置在所有排列中的对应值的和应当是相等的,对所有等价类分析,就可以得到这个结果。
于是对应的答案就为 \(f(n - 1, m - mx, mx)\times \frac{m - mx}{n -1}\)

对于第 \(2\) 种情况,依然是只考虑一个元素,最后乘上 \(n - 2\)
对于一个元素,那就是要求 \(a_{i - 1} = a_n = mx\)
于是答案就是 \(f(n - 2, m - 2mx, mx)\times (n - 2)\times \frac{m - 2mx}{n - 2}\)

对于第 \(3\) 种情况,那就是要求 \(a_{n - 1} = a_n = mx\)
于是答案是 \(f(n - 2, m - 2mx, mx)\times mx\)

最后的时间复杂度为 \(\mathcal{O}(m\ln m + \log \operatorname{mod})\)

其实在 \(n = 2\) 的时候应该判一下第 \(2\) 部分的第 \(2\) 情况的,因为会涉及到除 \(0\),不过因为本身会乘上 \(n - 2 = 0\) 所以其实随便乘上一个不对的数也没什么问题。

#include <bits/stdc++.h>

using ll = long long;

constexpr ll mod = 1000000007;
constexpr int N = 4e5;

inline ll qpow(ll a, ll b) {
    ll v = 1;
    for (; b; b >>= 1, a = a * a % mod) {
        if (b & 1) {
            v = v * a % mod;
        }
    }
    return v;
}

ll fac[N + 1], ifac[N + 1];
inline void init() {
    for (int i = fac[0] = 1; i <= N; i++) {
        fac[i] = fac[i - 1] * i % mod;
    }
    ifac[N] = qpow(fac[N], mod - 2);
    for (int i = N; i >= 1; i--) {
        ifac[i - 1] = ifac[i] * i % mod;
    }
}


inline ll C(int n, int m) {
    return n < m || m < 0 ? 0ll : (fac[n] * ifac[n - m] % mod * ifac[m] % mod);
}
inline ll f(int n, int m, int mx) {
    if (m < 0) {
        return 0;
    }
    if (n == 0) {
        return m == 0;
    }
    ll ans = 0;
    for (int i = 0; i * (mx + 1) <= m && i <= n; i++) {
        ll res = C(n, i) * C(m - i * (mx + 1) + n - 1, n - 1) % mod;
        ans = (ans + mod + (i & 1 ? -res : res)) % mod;
    }
    return ans;
}

inline void solve() {
    int n, m;
    scanf("%d%d", &n, &m);
    ll inv1 = qpow(n - 1, mod - 2), inv2 = qpow(n - 2, mod - 2);
    ll ans = 0;
    for (int mx = 0; mx <= m; mx++) {
        ll x1 = f(n - 1, m - mx, mx) * mx % mod;
        ll x2 = f(n - 2, m - mx * 2, mx) * (n - 1) % mod * mx % mod;
        ll y1 = f(n - 1, m - mx, mx) * (m - mx) % mod * inv1 % mod;
        ll y2 = f(n - 2, m - mx * 2, mx) * (n - 2) % mod * (m - mx * 2) % mod * inv2 % mod;
        ll y3 = f(n - 2, m - mx * 2, mx) * mx % mod;
        ans = (ans + x1 + x2 - y1 - y2 - y3 + mod * 3) % mod;
    }
    printf("%lld\n", ans);
}

int main() {
    init();
    int t;
    scanf("%d", &t);
    while (t--) {
        solve();
    }
    return 0;
}
posted @ 2025-09-17 19:32  rizynvu  阅读(27)  评论(0)    收藏  举报