CF549F Yura and Developers
这不今年完善程序 T2 吗?
给定数组 \(a_1\sim a_n\) 和常数 \(k\),求有多少个区间 \([l,r]\),满足:
\(r-l+1\ge 2\)。
\(\left(\sum\limits_{i=l}^r a_i-\max\limits_{i=l}^r a_i\right)\bmod k=0\)。
\(n\le 3\times 10^5\),\(k\le 10^6\)。
考虑分治。设当前分治区间左端点为 \(l\),右端点为 \(r\),中点为 \(mid\),考虑如何计算跨过中点的贡献。
首先对于右半区间,维护 \(pre_p=\max\limits_{u=mid+1}^p a_u\),\(sum_p=\left(\sum\limits_{v=mid+1}^pa_v\right)\bmod k\),\(dif_p=(sum_p-pre_p)\bmod k\),即以 \(mid+1\) 为起点的前缀最大值、模 \(k\) 意义下的前缀和以及它们的差对 \(k\) 取模的值。
从右往左扫描跨过中点的区间的左端点 \(i\),记 \(suf=\max\limits_{v=i}^{mid}a_v\),\(s=\left(\sum\limits_{u=i}^{mid} a_u\right)\bmod k\),我们要找到一个位置 \(j\),使得:
说白了就是右端点取在 \(j\) 及其左边区间最大值位于 \(mid\) 及其左边,右端点取在 \(j\) 右边区间最大值位于 \(mid\) 右边。不难发现随着 \(\boldsymbol i\) 递减,\(\boldsymbol j\) 不降。
分别计算以 \(j\) 为界的两部分的贡献,对于 \(j\) 及其左边,要找到这样的右端点 \(x\),使得 \((s+sum_x-suf)\bmod k=0\),移项得 \(sum_x=(k-s+suf)\bmod k\);对于 \(j\) 右边,要找到这样的右端点 \(y\),使得 \((s+dif_y)\bmod k=0\),移项得 \(dif_y=(k-s)\bmod k\)。
问题变成求 \((mid,j)\) 中有多少 \(sum\) 值为 \((k-s+suf)\bmod k\),\([j,r]\) 中有多少 \(dif\) 值为 \((k-s)\bmod k\),主席树维护即可,然后被卡了。
两只 \(\log\) 跑不快啊 /fn!
然后你发现主席树的可持久化是没有意义的,因为 \(j\) 不降,因此一个版本不会被再次使用。
考虑维护 \(b_1,b_2\) 两个桶,分别表示扫到当前的 \(j\) 时 \([j,r]\) 中 \(dif\) 在每种值各出现了几次和 \((mid,j)\) 中 \(sum\) 在每种值各出现了几次。\(j\) 增加到 \(j+1\) 时,相当于 \((mid,j)\) 比原来多包含了一个 \(j\) 位置,将 \(b_{2_{sum_{j}}}\) 增加 \(1\);\([j,r]\) 比原来少包含了一个 \(j\) 位置,将 \(b_{1_{dif_{j}}}\) 减去 \(1\)。
这么一来,左端点 \(i\) 的贡献为 \(b_{2_{(k-s+suf)\bmod k}}+b_{1_{(k-s)\bmod k}}\),注意边界、负数取模以及清空(有人又没清空,我不说是谁)。
时间复杂度为 \(\mathcal{O}(n\log n)\),空间复杂度为 \(\mathcal{O}(n+k)\)。
#include <bits/stdc++.h>
#define ll long long // 防见祖宗。
using namespace std; const int N = 3e5 + 5, V = 1e6;
int n, k, a[N], pre[N], sum[N], b1[V], b2[V]; // 代码里 dif 数组就直接用 pre 和 sum 表示了。
template<class T> void read(T &x) {
x = 0; T f = 1; char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') f = -1;
for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - 48; x *= f;
}
template<class T> void write(T x) {
if (x > 9) write(x / 10); putchar(x % 10 + 48);
}
template<class T> void print(T x, char ed = '\n') {
if (x < 0) putchar('-'), x = -x; write(x), putchar(ed);
}
ll solve(int l, int r) {
if (l == r) return 0; int md = (l + r) >> 1;
ll ret = solve(l, md) + solve(md + 1, r); pre[md] = sum[md] = 0;
for (int i = md + 1; i <= r; ++i)
pre[i] = max(pre[i - 1], a[i]), sum[i] = (sum[i - 1] + a[i]) % k;
for (int i = r; i > md; --i) ++b1[(sum[i] - pre[i] % k + k) % k];
for (int i = md, suf = 0, s = 0, j = md + 1, pos; i >= l; --i) {
suf = max(suf, a[i]); s = (s + a[i]) % k;
for (; j <= r && pre[j] <= suf; ++j)
--b1[(sum[j] - pre[j] % k + k) % k], ++b2[sum[j]];
if (j > md + 1) ret += b2[(k - s + suf) % k];
if (j <= r) ret += b1[(k - s) % k];
}
for (int i = r; i > md; --i) // 清空。
b1[(sum[i] - pre[i] % k + k) % k] = b2[sum[i]] = 0;
return ret;
}
signed main() {
read(n), read(k); for (int i = 1; i <= n; ++i) read(a[i]);
return print(solve(1, n)), 0;
}

浙公网安备 33010602011771号