T1. 糖果购买

\(f(i, j, k)\) 为到第 \(i\) 个商店为止,总共买了 \(j\) 件物品,在当前商店买了 \(k\) 件商品时的最大价值和。那么我们有

\[f(i, j, k) = v_i \cdot k + \max_x \{ f(i-1, j-k, x) \} \]

首先注意到 \(f(i, *, *)\) 只和 \(f(i-1, *, *)\) 有关,所以可以使用滚动数组优化空间

然后我们可以通过维护 \(f(i, j, *)\) 的后缀最大值,将状态转移优化到 \(O(1)\)

最后,观察到,如果我们在某个商店买了 \(k\) 件物品,那么由于 \(c_i\) 的限制,我们买的总物品数量一定是不少于 \(\frac{k(k+1)}{2}\) 的。所以事实上我们不需要考虑 \(\frac{k(k+1)}{2} > m\) 的那些 \(k\)。这样的话 \(k\) 一维的大小就只有 \(O(\sqrt{m})\) 级别。

最终的时间复杂度为 \(O(nm\sqrt{m})\),空间复杂度为 \(O(m\sqrt{m})\)

代码实现
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 0; i < (n); ++i)
#define drep(i, n) for (int i = (n)-1; i >= 0; --i)

using namespace std;
using ll = long long;

inline void chmax(ll& a, ll b) { if (a < b) a = b; }

int main() {
    int n, m;
    cin >> n >> m;
    
    vector<int> v(n);
    rep(i, n) cin >> v[i];
	
	vector<vector<ll>> dp(n+1, vector<ll>(m+1));
	int lim = sqrt(2*m)+1; 
	drep(i, n)rep(j, m+1) {
	    dp[i][j] = dp[i+1][j];
	    ll sum = 0;
	    for (int k = i; k < min(n, i+lim); ++k) {
	        sum += v[k];
	        int len = k-i+1;
	        if (len > j) break;
	        chmax(dp[i][j], dp[i+1][j-len] + sum);
	    }
	    if (j) chmax(dp[i][j], dp[i][j-1]);
	}
	
	cout << dp[0][m] << '\n';
	
	return 0;
}

T2. 序列制作

我们建一张 \(n\) 个点的图,然后对于每个 \(i(1 \leqslant i \leqslant n)\) ,我们给点 \(i\) 和点 \(a_i\) 之间连一条有向边。那么实际上原题等价于,求有多少种给每个点染色的方案,使得相邻两个点的颜色不同。

建图的方式决定了这张图一定是一个基环树森林。可以观察到,孤立点可以染的颜色有 \(m\) 种,基环树上不在环上的点有 \(m-1\) 种染色方案。对于环上的 \(m\) 染色问题其实就是 圆环三染色

总复杂度为 \(O(n)\)

代码实现
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 0; i < (n); ++i)

using namespace std;
using ll = long long;

const int mod = 998244353;
//const int mod = 1000000007;
struct mint {
    ll x;
    mint(ll x=0):x((x%mod+mod)%mod) {}
    mint operator-() const {
        return mint(-x);
    }
    mint& operator+=(const mint a) {
        if ((x += a.x) >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint a) {
        if ((x += mod-a.x) >= mod) x -= mod;
        return *this;
    }
    mint& operator*=(const mint a) {
        (x *= a.x) %= mod;
        return *this;
    }
    mint operator+(const mint a) const {
        return mint(*this) += a;
    }
    mint operator-(const mint a) const {
        return mint(*this) -= a;
    }
    mint operator*(const mint a) const {
        return mint(*this) *= a;
    }
    mint pow(ll t) const {
        if (!t) return 1;
        mint a = pow(t>>1);
        a *= a;
        if (t&1) a *= *this;
        return a;
    }

    // for prime mod
    mint inv() const {
        return pow(mod-2);
    }
    mint& operator/=(const mint a) {
        return *this *= a.inv();
    }
    mint operator/(const mint a) const {
        return mint(*this) /= a;
    }
};
istream& operator>>(istream& is, mint& a) {
    return is >> a.x;
}
ostream& operator<<(ostream& os, const mint& a) {
    return os << a.x;
}

struct modinv {
  int n; vector<mint> d;
  modinv(): n(2), d({0,1}) {}
  mint operator()(int i) {
    while (n <= i) d.push_back(-d[mod%n]*(mod/n)), ++n;
    return d[i];
  }
  mint operator[](int i) const { return d[i];}
} invs;
struct modfact {
  int n; vector<mint> d;
  modfact(): n(2), d({1,1}) {}
  mint operator()(int i) {
    while (n <= i) d.push_back(d.back()*n), ++n;
    return d[i];
  }
  mint operator[](int i) const { return d[i];}
} facts;
struct modfactinv {
  int n; vector<mint> d;
  modfactinv(): n(2), d({1,1}) {}
  mint operator()(int i) {
    while (n <= i) d.push_back(d.back()*invs(n)), ++n;
    return d[i];
  }
  mint operator[](int i) const { return d[i];}
} ifacts;
mint comb(int n, int k) {
  if (n < k || k < 0) return 0;
  return facts(n)*ifacts(k)*ifacts(n-k);
}

void solve() {
    int n, m;
    cin >> n >> m;
    
    vector<int> a(n);
    rep(i, n) cin >> a[i], a[i]--;
    
    mint ans = 1;
    ll tot = 0;
    
    vector<int> ord(n, -1);
    auto dfs = [&](auto& f, int v, int k=1) -> void {
        if (ord[v] == -1) {
            ord[v] = k;
            f(f, a[v], k+1);
        }
        else if (ord[v] > 0) {
            int l = k-ord[v];
            mint now = mint(m-1).pow(l);
            if (l&1) now -= m-1;
            else now += m-1;
            ans *= now;
            tot += l;
        }
        ord[v] = 0;
    };
    rep(i, n) dfs(dfs, i);
    
    ans *= mint(m-1).pow(n-tot);
    
    cout << ans << '\n';
}

int main() {
    cin.tie(nullptr) -> sync_with_stdio(false);
    
    int t;
    cin >> t;
    
    while (t--) solve();
    
	return 0;
}

T3. 符号翻转

对于固定的 \(k\),区间 \([1, P]\) 的最大前缀和可表示为:

\[c(k, P) = S_i - 2(\text{区间}[1, P] \text{内} k \text{个最小负数的和}) \]

其中 \(S_i\) 为原始前缀和。这是因为翻转一个元素 \(a_i\) 带来的成本是 \(-2a_i\)

对于每个固定的 \(k\),我们需要找到使 \(c(k, P)\) 最大的 \(P\) 值,记为 \(P_k\)。如果有多个 \(P\) 都能使 \(c(k, P)\) 达到最大值,则取其中最小的那个作为 \(P_k\)(此处只是任意约定,对结果无影响)。

断言:\(P_k \leqslant P_{k+1}\)

证明:
通常,证明此类不等式的技巧是反证法。假设 \(P_k > P_{k+1}\)
\(L_k\)\(R_k\) 分别表示区间 \([1, P_{k+1}]\)\([P_{k+1}+1, P_k]\) 中最小的负数之和。注意到 \(L\)\(R\) 都是凸的。
假设在 \(c(k, P_k)\) 的最优解中,我们在区间 \([1, P_{k+1}]\) 中翻转了 \(k_1\) 个元素,在区间 \([P_{k+1}+1, P_k]\) 中翻转了 \(k_2\) 个元素。
\(P_k\) 的定义,必有 $$c(k, P_{k+1}) < c(k, P_k)$$
即:$$S_{P_{k+1}} - L_k < S_{P_k} - L_{k_1} - R_{k_2}$$
现在,\(c(k+1, P_k) \geqslant S_{P_k} - L_{k_1+1} - R_{k_2}\)
因为我们通过在区间 \([1, P_{k+1}]\) 翻转 \(k_1+1\) 个元素、在区间 \([P_{k+1}+1, P_k]\) 翻转 \(k_2\) 个元素来达到这个代价。由于 \(L\) 是凸的,故有 \(L_k - L_{k+1} \leqslant L_{k_1} - L_{k_1+1}\)
综合所有不等式可得:$$c(k+1, P_{k+1}) = S_{P_{k+1}} - L_{K+1} < S_{P_k} - L_{k_1+1} - R_{k_2} \leqslant c(k+1, P_k)$$ 这与 \(P_{k+1}\) 最优性的假设矛盾。

基于该性质,我们可以采用分治策略。
具体来说,我们可以构建函数 solve(l, r, optl, optr),当需要计算中间位置 \(\lfloor\frac{l+r}{2}\rfloor\) 的最优解时,可以确定 \(\mathrm{optl} \leqslant P_m \leqslant \mathrm{optr}\)

现在,如果我们能对任意的 \(m\)\(k\) 快速求出 \(c(k, m)\),就能解决这道题。这个问题归结为在区间 \([1, m]\) 中找 \(k\) 个最小元素,使用 \(\mathrm{wavelet} \ \mathrm{tree}\) 可以在 \(O(\log n)\) 时间内完成。不过,还有一种更简单的方法。

考虑一个数据结构,它维护一个可重集 \(S\) 和一个整数 \(k\),并能在对数时间内支持以下操作:

  • \(k\) 增加或减少 \(1\)
  • \(S\) 中插入或删除元素
  • 获取 \(S\) 中最小的 \(\min(k, |S|)\) 个元素之和

这个数据结构可以通过维护两个 multiset \(A\)\(B\) 来实现,要求:

  • \(\max(A) \leqslant \min(B)\)
  • 要么 \(|A| = k\),要么 \(|a| < k\)\(B = \varnothing\)

这样,查询的答案就是集合 \(A\) 中元素的和。维护这两个集合非常简单。

总的时间复杂度为 \(O(N\log^2 N)\)