CF549F Yura and Developers

这不今年完善程序 T2 吗?

CF 洛谷

  • 给定数组 \(a_1\sim a_n\) 和常数 \(k\),求有多少个区间 \([l,r]\),满足:

    • \(r-l+1\ge 2\)

    • \(\left(\sum\limits_{i=l}^r a_i-\max\limits_{i=l}^r a_i\right)\bmod k=0\)

  • \(n\le 3\times 10^5\)\(k\le 10^6\)

考虑分治。设当前分治区间左端点为 \(l\),右端点为 \(r\),中点为 \(mid\),考虑如何计算跨过中点的贡献。

首先对于右半区间,维护 \(pre_p=\max\limits_{u=mid+1}^p a_u\)\(sum_p=\left(\sum\limits_{v=mid+1}^pa_v\right)\bmod k\)\(dif_p=(sum_p-pre_p)\bmod k\),即以 \(mid+1\) 为起点的前缀最大值、模 \(k\) 意义下的前缀和以及它们的差对 \(k\) 取模的值。

从右往左扫描跨过中点的区间的左端点 \(i\),记 \(suf=\max\limits_{v=i}^{mid}a_v\)\(s=\left(\sum\limits_{u=i}^{mid} a_u\right)\bmod k\),我们要找到一个位置 \(j\),使得:

\[\max\limits_{u=mid+1}^{j-1} a_u\le suf\, \land \,\forall \,w\in[j,r],\max\limits_{u=mid+1}^w a_u> suf \]

说白了就是右端点取在 \(j\) 及其左边区间最大值位于 \(mid\) 及其左边,右端点取在 \(j\) 右边区间最大值位于 \(mid\) 右边。不难发现随着 \(\boldsymbol i\) 递减,\(\boldsymbol j\) 不降

分别计算以 \(j\) 为界的两部分的贡献,对于 \(j\) 及其左边,要找到这样的右端点 \(x\),使得 \((s+sum_x-suf)\bmod k=0\),移项得 \(sum_x=(k-s+suf)\bmod k\);对于 \(j\) 右边,要找到这样的右端点 \(y\),使得 \((s+dif_y)\bmod k=0\),移项得 \(dif_y=(k-s)\bmod k\)

问题变成求 \((mid,j)\) 中有多少 \(sum\) 值为 \((k-s+suf)\bmod k\)\([j,r]\) 中有多少 \(dif\) 值为 \((k-s)\bmod k\),主席树维护即可,然后被卡了

两只 \(\log\) 跑不快啊 /fn!

然后你发现主席树的可持久化是没有意义的,因为 \(j\) 不降,因此一个版本不会被再次使用。

考虑维护 \(b_1,b_2\) 两个桶,分别表示扫到当前的 \(j\)\([j,r]\)\(dif\) 在每种值各出现了几次和 \((mid,j)\)\(sum\) 在每种值各出现了几次。\(j\) 增加到 \(j+1\) 时,相当于 \((mid,j)\) 比原来多包含了一个 \(j\) 位置,将 \(b_{2_{sum_{j}}}\) 增加 \(1\)\([j,r]\) 比原来少包含了一个 \(j\) 位置,将 \(b_{1_{dif_{j}}}\) 减去 \(1\)

这么一来,左端点 \(i\) 的贡献为 \(b_{2_{(k-s+suf)\bmod k}}+b_{1_{(k-s)\bmod k}}\),注意边界、负数取模以及清空(有人又没清空,我不说是谁)。

时间复杂度为 \(\mathcal{O}(n\log n)\),空间复杂度为 \(\mathcal{O}(n+k)\)

提交记录

#include <bits/stdc++.h>
#define ll long long // 防见祖宗。
using namespace std; const int N = 3e5 + 5, V = 1e6; 
int n, k, a[N], pre[N], sum[N], b1[V], b2[V]; // 代码里 dif 数组就直接用 pre  和 sum 表示了。
template<class T> void read(T &x) {
    x = 0; T f = 1; char c = getchar();
    for (; !isdigit(c); c = getchar()) if (c == '-') f = -1;
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - 48; x *= f;
}
template<class T> void write(T x) {
    if (x > 9) write(x / 10); putchar(x % 10 + 48);
}
template<class T> void print(T x, char ed = '\n') {
    if (x < 0) putchar('-'), x = -x; write(x), putchar(ed);
}
ll solve(int l, int r) {
    if (l == r) return 0; int md = (l + r) >> 1; 
    ll ret = solve(l, md) + solve(md + 1, r); pre[md] = sum[md] = 0;
    for (int i = md + 1; i <= r; ++i) 
        pre[i] = max(pre[i - 1], a[i]), sum[i] = (sum[i - 1] + a[i]) % k;
    for (int i = r; i > md; --i) ++b1[(sum[i] - pre[i] % k + k) % k];
    for (int i = md, suf = 0, s = 0, j = md + 1, pos; i >= l; --i) {
        suf = max(suf, a[i]); s = (s + a[i]) % k; 
        for (; j <= r && pre[j] <= suf; ++j)
            --b1[(sum[j] - pre[j] % k + k) % k], ++b2[sum[j]];
        if (j > md + 1) ret += b2[(k - s + suf) % k];
        if (j <= r) ret += b1[(k - s) % k]; 
    }
    for (int i = r; i > md; --i) // 清空。
        b1[(sum[i] - pre[i] % k + k) % k] = b2[sum[i]] = 0;
    return ret;
}
signed main() {
    read(n), read(k); for (int i = 1; i <= n; ++i) read(a[i]);
    return print(solve(1, n)), 0;
}
posted @ 2023-09-23 10:43  lzyqwq  阅读(42)  评论(0)    收藏  举报