LG7882 [Ynoi2006] rsrams【阈值法,分块,莫队】

给定长为 \(n\) 的序列 \(a_1,\cdots,a_n\)\(m\) 次询问区间 \([L,R]\),求其所有子区间的绝对众数之和。

\(n,m\le 10^6\)\(1\le a_i\le n\),时限 \(8.0\text{s}\)


若固定绝对众数是 \(x\),要求多少个子区间的 \(2[a_i=x]-1\) 之和 \(>0\),取前缀和之后问题就是区间顺序对计数。

优化当然考虑对出现次数 \(c\) 根号分治。对于总和 \(>0\) 的区间 \([l,r]\),在前缀和 \(=0\) 的位置分段,每段都满足第一个元素是 \(1\) 且所有前缀和 \(\ge 0\),或者左右翻转的情况。所以我们枚举每个 \(1\) 的位置 \(p\),向前后分别扩展出后缀和/前缀和均 \(\ge 0\) 的极长区间 \((l,p]\)\([p,r)\),将区间 \([l,r]\) 合并到同一个连通块,此时顺序对只会在一个连通区间内出现,且这样的区间的总长是 \(\mathcal O(c)\) 的。

对于长度 \(>\sqrt m\) 的区间,只有 \(\mathcal O(n/\sqrt m)\) 个,分别计算贡献到询问,就是在这个区间内做 \(m\) 次查询的莫队,注意不能直接对询问排序,你需要事先将所有询问按右端点排序,每次做莫队时对左端点所在块做桶排序,时间复杂度 \(\mathcal O(n\sqrt m)\)

对于长度 \(\le\sqrt m\) 的区间,总的顺序对个数只有 \(\mathcal O(n\sqrt m)\),查询次数 \(\mathcal O(\sqrt m)\),使用 \(\mathcal O(1)\) 修改 \(\mathcal O(\sqrt n)\) 查询的分块维护二维数点,时间复杂度 \(\mathcal O(n\sqrt m+m\sqrt n)\)

#include<bits/stdc++.h>
#define fi first
#define se second
using namespace std;
typedef long long LL;
typedef pair<int, int> pii;
const int N = 1000003;
int n, m, nb, mb, blen, bl[N], a[N], sum[N * 2], *buc = sum + N;
vector<int> v[N], ins[N], del[N];
LL ans[N];
struct Query {
	int l, r, id;
	Query(int _1 = 0, int _2 = 0, int _3 = 0): l(_1), r(_2), id(_3){}
	bool operator < (const Query &o) const {return r < o.r;}
} q[N];
vector<pii> tq[N];
struct BIT {
	LL val[N], sum[1003];
	void upd(int p, int v){val[p] += v; sum[p / nb] += v;}
	LL qry(int p){
		LL res = 0;
		for(int i = p / nb - 1;i >= 0;-- i) res += sum[i];
		for(int i = p / nb * nb;i <= p;++ i) res += val[i];
		return res;
	}
} tr;
int main(){
	ios::sync_with_stdio(0);
	cin >> n >> m; nb = sqrt(n); mb = sqrt(m);
	for(int i = 1;i <= n;++ i){
		cin >> a[i]; v[a[i]].push_back(i);
	}
	for(int i = 1;i <= m;++ i){
		cin >> q[i].l >> q[i].r; q[i].id = i;
		tq[q[i].l - 1].emplace_back(q[i].r, i);
		tq[q[i].r].emplace_back(q[i].r, i);
	}
	sort(q + 1, q + m + 1);
	vector<Query> lar, sma;
	for(int i = 1;i <= n;++ i) if(!v[i].empty()){
		vector<pii> t0, t1, t2;
		int nl = n + 1, nr = n + 1;
		for(int j = (int)v[i].size() - 1;j >= 0;-- j)
			if(v[i][j] < nl){
				if(nr <= n) t0.emplace_back(nl, nr);
				nr = v[i][j]; nl = max(nr - 1, 1);
			} else nl = max(nl - 2, 1);
		t0.emplace_back(nl, nr);
		reverse(t0.begin(), t0.end());
		nl = nr = 0;
		for(int j = 0;j < v[i].size();++ j)
			if(v[i][j] > nr){
				if(nl) t1.emplace_back(nl, nr);
				nl = v[i][j]; nr = min(nl + 1, n); 
			} else nr = min(nr + 2, n);
		t1.emplace_back(nl, nr);
		int p0 = 0, p1 = 0; nl = nr = 0;
		while(p0 < t0.size() || p1 < t1.size()){
			if(p0 != t0.size() && (p1 == t1.size() || t0[p0].fi < t1[p1].fi)){
				if(nr < t0[p0].fi){
					if(nl) t2.emplace_back(nl, nr);
					nl = t0[p0].fi; nr = t0[p0].se;
				} else nr = max(nr, t0[p0].se);
				++ p0;
			} else {
				if(nr < t1[p1].fi){
					if(nl) t2.emplace_back(nl, nr);
					nl = t1[p1].fi; nr = t1[p1].se;
				} else nr = max(nr, t1[p1].se);
				++ p1;
			}
		}
		t2.emplace_back(nl, nr);
		for(const auto &[L, R] : t2) (R - L >= mb ? lar : sma).emplace_back(L, R, i);
	}
	for(const auto &[ql, qr, val] : lar){
		vector<Query> nq;
		for(int i = 1;i <= m;++ i)
			if(q[i].l <= qr && q[i].r >= ql)
				nq.emplace_back(max(q[i].l, ql), min(q[i].r, qr), q[i].id);
		if(nq.empty()) continue;
		blen = max(1., (qr - ql + 1) / sqrt(nq.size()));
		for(int i = ql;i <= qr;++ i) bl[i] = (i - ql) / blen + 1;
		memset(sum, 0, (bl[qr] + 1) << 2);
		for(int i = 0;i < nq.size();++ i) ++ sum[bl[nq[i].l]];
		for(int i = 1;i <= bl[qr];++ i) sum[i] += sum[i - 1];
		vector<Query> nxtq(nq.size());
		for(int i = 0;i < nq.size();++ i) nxtq[sum[bl[nq[i].l] - 1] ++] = nq[i];
		nq.swap(nxtq); buc[0] = 1;
		int nl = ql, nr = ql - 1, sl = 0, sr = 0, ssl = 0, ssr = 0; LL res = 0;
		for(const auto &[l, r, id] : nq){
			while(nr < r){
				if(a[++ nr] == val) ssr += buc[sr ++];
				else ssr -= buc[-- sr];
				res += ssr; ++ buc[sr]; if(sl < sr) ++ ssl;
			}
			while(nl > l){
				if(a[-- nl] == val) ssl += buc[sl --];
				else ssl -= buc[++ sl];
				res += ssl; ++ buc[sl]; if(sl < sr) ++ ssr;
			}
			while(nr > r){
				res -= ssr; -- buc[sr]; if(sl < sr) -- ssl;
				if(a[nr --] == val) ssr -= buc[-- sr];
				else ssr += buc[sr ++];
			}
			while(nl < l){
				res -= ssl; -- buc[sl]; if(sl < sr) -- ssr;
				if(a[nl ++] == val) ssl -= buc[++ sl];
				else ssl += buc[sl --];
			}
			ans[id] += res * val;
		}
		int len = qr - ql + 1; memset(buc - len, 0, (len + 1) << 3);
	}
	for(int i = 0;i < sma.size();++ i){
		ins[sma[i].l].push_back(i);
		del[sma[i].r + 1].push_back(i);
	}
	set<int> st;
	for(int i = 1;i <= n;++ i){
		for(int j : ins[i]) st.insert(j);
		for(int j : del[i]) st.erase(j);
		for(int j : st)
			for(int k = i, nv = 0;k <= sma[j].r;++ k)
				if((nv += (a[k] == sma[j].id ? 1 : -1)) > 0)
					tr.upd(k, sma[j].id);
		for(auto [j, id] : tq[i]) ans[id] += (i == j ? 1 : -1) * tr.qry(j);
	}
	for(int i = 1;i <= m;++ i) cout << ans[i] << '\n';
}
posted @ 2022-08-02 22:21  mizu164  阅读(162)  评论(0编辑  收藏  举报