题解 [POI2011] MET-Meteors
题目链接
题目描述
以下是简单转化后的题目描述:
有一个长度为 \(m\) 的环状正整数序列(即 \(m\) 与 \(1\) 相邻),初始全为 \(0\)。
有 \(k\) 次区间加的操作。
有 \(n\) 个国家,正整数序列的每一个元素恰属于其中一个国家,第 \(i\) 个国家有一个目标值 \(p_i\)。
对每个国家,询问至少多少次区间加之后,国家内所有元素和不小于 \(p_i\),如果 \(k\) 次操作之后依然小于 \(p_i\),输出 NIE
其中 \(n, m, k \le 3 \times 10^5\)。
分析
虽然网上大部分整体二分题解复杂度都是 \(O(n \log^2 n)\),但实际上整体二分可以做到 \(O(n \log n)\) (虽然常数大)。
对于每个集合,答案显然具有单调性,即若 \(a\) 次区间加后恰好不小于 \(p_i\),那小于 \(a\) 次都不行,大于 \(a\) 次都行。
所以我们可以考虑整体二分。
即我们需要实现 solve(cl, cr, al, ar) 函数,表示解决国家在 \(cl\) 到 \(cr\) 之间,答案在 \(al\) 到 \(ar\) 之间的子问题。
我们设 \(mid = \frac{al+ar}{2}\),我们需要 solve 函数能将答案小于等于 \(mid\) 和大于 \(mid\) 的国家分开来,并递归处理。
假设在 \(cl\) 到 \(cr\) 之间所有国家共包含 \(n\) 个序列中的元素,\(al\) 到 \(ar\) 之间共有 \(m\) 次区间加的操作,solve 函数的时间复杂度是 \(O(T(n + m))\)。
那总复杂度就是 \(O(T(n) \log n)\),证明较为平凡,这里不再给出。
网上大部分题解都是 \(T(n) = n \log n\),但实际上可以做到 \(T(n) = n\)。
考虑 solve 函数需要干什么:
需要把操作编号在 \(al\) 到 \(mid\) 之间的区间加执行,并对于 \(cl\) 到 \(cr\) 之间的每个国家判断当前的和是否达到了目标。
如果单看这一个问题,这显然是 \(O(n)\) 的,因为所有的区间加都在判断之前,所以我们可以差分。
但为什么还要使用树状数组呢?
因为递归之后的子问题中,序列中的元素不再是连续的,这样再差分复杂度会炸。
其实解决方法也很简单,我们只需要离散化即可。
具体而言,我们破环成链,并将区间加差分成后缀加,并在初始时将所有后缀加与序列一起按照下标排序。
对于一个 solve 函数,我们先执行差分,并判断每个后缀加和当前序列中的每个元素应该递归到哪边。
之后对递归到左右两边的所有元素以下标为关键字各进行一遍基数排序,并得到所有元素在各自部分的排名,并让其等于新的下标,递归即可。
由于是基数排序,所以复杂度依然是 \(O(n)\)。
不过因为我们已经将后缀加与序列中的元素按照下标排序了,所以我们可以直接扫一遍以代替基数排序。
上面所有的操作复杂度都是 \(O(n)\)。
其实,不带修区间 \(k\) 小也能用类似的做法做到 \(O(n \log n)\)。
代码
卡空间卡了一上午...
还有这破题居然卡 long long,必须得上 int128 才能过(ull 因为没有负数所以无法实现差分)
#include <bits/stdc++.h>
using namespace std;
typedef __int128 LL;
const int N = 300010;
const int M = 1300010;
const int K = 700010;
struct node{
int type, t, pos, val; // pos 就是在当且 solve 函数中的下标,type 对应类型(1 为后缀加,2 为序列元素)
bool operator <(const node a)const{
return pos < a.pos;
}
}a[M], lft[K], rgt[K];
struct country{
int op, num, res, h;
LL nd;
}c[N];
int n, m, k, cnt, tot, bel[N], ans[N], nxt[N], to[N];
LL d[N];
unsigned long long sum[N];
inline int read()
{
int x = 0, f = 1;
char c = getchar();
while(c < '0' || c > '9') { if(c == '-') f = -1; c = getchar(); }
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
void addedge(int u, int v)
{
to[++tot] = v;
nxt[tot] = c[u].h, c[u].h = tot;
}
void add(node *b, int &tot, int &pos, node cur)
{
b[++tot] = cur;
if(cur.pos == pos) b[tot].pos = b[tot - 1].pos;
else b[tot].pos = b[tot - 1].pos + 1;
pos = cur.pos;
// 得到新的下标
}
void solve(int pl, int pr, int cl, int cr, int al, int ar)
{
// 操作与序列元素在 a 数组中的下标在 pl 到 pr 之间(此时 pos 已离散化)
// 国家在 cl 到 cr 之间
// 答案在 al 到 ar 之间
if(al == ar){
for(int i = cl; i <= cr; i++) ans[c[i].num] = al;
return ;
}
int mid = (al + ar) >> 1, mx = 0;
for(int i = pl; i <= pr; i++){
if(a[i].type && a[i].t <= mid) d[a[i].pos] += a[i].val;
mx = max(mx, a[i].pos);
}
// 差分实现区间加
for(int i = 1; i <= mx; i++) d[i] += d[i - 1];
for(int i = pl; i <= pr; i++)
if(!a[i].type) sum[a[i].t] = d[a[i].pos];
for(int i = cl; i <= cr; i++){
LL cur = 0;
for(int j = c[i].h; j; j = nxt[j]) cur += sum[to[j]];
if(cur >= c[i].nd){
c[i].op = 0;
for(int j = c[i].h; j; j = nxt[j]) sum[to[j]] = 0;
}
else{
c[i].op = 1, c[i].nd -= cur;
for(int j = c[i].h; j; j = nxt[j]) sum[to[j]] = 1;
}
}
// 判断每个国家是否达到了目标
int tot1 = 0, tot2 = 0, pos1 = 0, pos2 = 0;
lft[0].pos = rgt[0].pos = 0;
for(int i = pl; i <= pr; i++){
if(a[i].type)
if(a[i].t <= mid) add(lft, tot1, pos1, a[i]);
else add(rgt, tot2, pos2, a[i]);
else
if(sum[a[i].t]) add(rgt, tot2, pos2, a[i]);
else add(lft, tot1, pos1, a[i]);
}
// 离散化
for(int i = pl; i < pl + tot1; i++) a[i] = lft[i - pl + 1];
for(int i = pl + tot1; i <= pr; i++) a[i] = rgt[i - pl - tot1 + 1];
int l = cl, r = cr;
while(l <= r){
while(!c[l].op) l++;
while(c[r].op) r--;
if(l <= r) swap(c[l], c[r]);
}
for(int i = 1; i <= mx; i++) d[i] = 0;
for(int i = pl; i <= pr; i++) sum[a[i].t] = 0;
solve(pl, pl + tot1 - 1, cl, l - 1, al, mid);
solve(pl + tot1, pr, l, cr, mid + 1, ar);
}
int main()
{
n = read(), m = read();
for(int i = 1; i <= n; i++) c[i].h = 0;
for(int i = 1; i <= m; i++)
bel[i] = read(), addedge(bel[i], i);
for(int i = 1; i <= n; i++)
c[i].num = i, c[i].nd = read();
k = read();
for(int i = 1, l, r, v; i <= k; i++){
l = read(), r = read(), v = read();
if(l <= r){
a[++cnt] = (node){1, i, l, v};
if(r + 1 <= m) a[++cnt] = (node){1, i, r + 1, -v};
}
else{
a[++cnt] = (node){1, i, 1, v};
a[++cnt] = (node){1, i, r + 1, -v};
a[++cnt] = (node){1, i, l, v};
}
}
for(int i = 1; i <= m; i++) a[++cnt] = (node){0, i, i, 0};
sort(a + 1, a + cnt + 1);
solve(1, cnt, 1, n, 1, k + 1);
for(int i = 1; i <= n; i++)
if(ans[i] == k + 1) puts("NIE");
else printf("%d\n", ans[i]);
return 0;
}

浙公网安备 33010602011771号