把博客园图标替换成自己的图标
把博客园图标替换成自己的图标end

数据结构 题目分析(牛逼做法非题解做法)

概述

给你一个长度为 \(n\) 的排列 \(p\),定义集合:

\[S_i=\{x\mid x\geq i\wedge \max_{k\in[i,x-1]}p_k<p_x\} \]

\(q\) 组询问,每个询问询问一个区间 \([l,r]\),让你求 \(\sum_{l\leq x,y\leq r}|S_x\cup S_y|\)

思路

说明一下,题解的做法是把他转化成 \(|S_x|+|S_y|-|S_x\cap S_y|\) 然后求后者。

但是此题解的思路来自翁巨和罗大神,可以不用直接转化。

先观察 \(S_x\) 看起来就跟题解说的一样:从后往前的升序单调栈。

但是我们换种思路,求 \(x\) 在哪些 \(S_i\) 中,显然可以用单调栈实现。

那么 \(x\) 存在的 \(S_i\) 一定是一个区间 \([l,r]\)

假设查询的区间是 \([L,R]\),对于 \(x\) 来讲包含它的 \(S_i\)\(k\) 个(在这个区间中),并设 \(len=R-L+1\)

那么它会贡献答案多少呢?

显然是 \(1\),我们再看看那些 \(S_x\)\(S_y\) 取并集后含有 \(x\) 的方案即可。

首先分个类:

  • 两者都有 \(x\),方案为 \(k^2\)
  • 其中一者有 \(x\),方案为 \(k\times(len-k)\times 2\)

总贡献就是:\(k^2+k\times(len-k)\times2=2len\times k - k^2\)

对于每个 \(x\) 都要求,所以对于整个区间的贡献就是:

\[2len\sum_{i\in[L,R]}k_{a_i} - \sum_{i\in[L,R]}{k_{a_i}}^2 \]

这样预处理过后做是 \(\mathcal{O}(n+nq)\) 的,可以拿很多分(\(60\) 分吧?)。

我们关注一下那个 \(k\),是要现场求而且是跟区间叠加有关,这让我们想到了分类讨论(这样的分类讨论必须不重不漏,要注意细节,我的代码是通过排序来说明以下讨论的情况限制的)。

假设 \(x\) 的区间为 \([l,r]\)

第一种情况:\(l<L\leq r\leq R\)

这里的 \(k\)\(r-L+1\)

没错,我们可以把问题离线下来按照左端点排序并用线段树求平方和以及和,从左到右插入 \(r\),并把贡献的部分拆开换成平方和以及和的形式进行计算即可。

第二种情况:\(L\leq l\leq R<r\)

同理。

第三种情况(被包含情况):\(L\leq l\leq r\leq R\)

这里可以按照 \(l\) 从大到小排序,插入 \(r-l+1\) 即可。

这里的 \(k=r-l+1\)

第四种情况(包含情况):\(l<L,r<R\)

这里的贡献是 \(k=R-L+1\),统计这种情况的个数就可以了。

代码

代码如下,时间复杂度 \(\mathcal{O}((n+q)\log n+q)\),本人错了一次样例后一遍过。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <stdlib.h>
#include <algorithm>
#include <vector>
#include <stack>
#define int long long
#define N 250004
using namespace std;
#define isdigit(ch) ('0' <= ch && ch <= '9')
template<typename T>
void read(T &x) {
	x = 0;
	int f = 1;
	char ch = getchar();
	for (;!isdigit(ch);ch = getchar()) f = (ch == '-' ? -1 : f);
	for (;isdigit(ch);ch = getchar()) x = (x << 3) + (x << 1) + (ch ^ 48);
	x *= f; 
}
template<typename T>
void write(T x) {
	if (x < 0) x = -x,putchar('-');
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
#define ls(x) (x << 1)
#define rs(x) (x << 1 | 1) 
struct TR{
	int val1,val2,val3;
}tr[N << 2];
void build(int x,int l,int r) {
	tr[x] = {0,0,0};
	if (l == r) return;
	int mid = l + r >> 1;
	build(ls(x),l,mid),build(rs(x),mid + 1,r);
}
void pushup(int x) {
	tr[x] = {tr[ls(x)].val1 + tr[rs(x)].val1,tr[ls(x)].val2 + tr[rs(x)].val2,tr[ls(x)].val3 + tr[rs(x)].val3};
}
void update(int x,int l,int r,int pos,int val) {
	if (l == r) {
		tr[x].val1 += val;
		tr[x].val2 += val * val;
		tr[x].val3 ++;
		return;
	}
	int mid = l + r >> 1;
	if (pos <= mid) update(ls(x),l,mid,pos,val);
	else update(rs(x),mid + 1,r,pos,val);
	pushup(x);
}
TR query(int x,int l,int r,int L,int R) {
	if (l > R || r < L) return {0,0,0};
	if (L <= l && r <= R) return tr[x];
	int mid = l + r >> 1;
	TR t1 = query(ls(x),l,mid,L,R),t2 = query(rs(x),mid + 1,r,L,R);
	return {t1.val1 + t2.val1,t1.val2 + t2.val2,t1.val3 + t2.val3};
}
int n,T,type,ans[N];
struct abc{
	int l,r;
}qu[N],exch[N];
int a[N];
struct node{
	int l,r,type,id;
}s[N << 1];
int cnt;
stack<int> sta;
signed main(){
	read(n),read(T),read(type);
	for (int i = 1;i <= n;i ++) read(a[i]);
	for (int i = 1;i <= T;i ++) read(qu[i].l),read(qu[i].r);
	for (int i = 1;i <= T;i ++) s[i] = {qu[i].l,qu[i].r,1,i};
	cnt = T;
	for (int i = n;i;i --) {
		exch[i].r = i;
		while(!sta.empty() && a[sta.top()] < a[i]) {
			exch[sta.top()].l = i + 1;
			sta.pop();
		}
		sta.push(i);
	}
	while(!sta.empty()) exch[sta.top()].l = 1,sta.pop();
	for (int i = 1;i <= n;i ++) s[++cnt] = {exch[i].l,exch[i].r,2,0};
//    for (int i = 1;i <= cnt;i ++) cout << s[i].l << ' ' << s[i].r << ' ' << s[i].type << ' ' << s[i].id << '\n';
	stable_sort(s + 1,s + 1 + cnt,[](node x,node y) {
		if (x.l != y.l) return x.l < y.l;
		if (x.type != y.type) return x.type < y.type;
		return x.r < y.r;
	});
	//1 4
	for (int i = 1;i <= cnt;i ++)
		if (s[i].type == 2) update(1,1,n,s[i].r,s[i].r);//x^2,x
		else {
			int L = s[i].l,R = s[i].r,id = s[i].id;
			int len = R - L + 1;
			int tk = L - 1;
			TR ge = query(1,1,n,L,R);
			ans[id] += 2 * len * (ge.val1 - ge.val3 * tk) - (ge.val2 - 2 * tk * ge.val1 + ge.val3 * tk * tk);
			TR ge2 = query(1,1,n,R + 1,n);
			ans[id] += len * ge2.val3 * len;//2*len\sum k-2 * len\sumk^2 k=r-l+1=len -> 2*len^2*cnt-len^2*cnt.
		}
	//3
	build(1,1,n);
	stable_sort(s + 1,s + 1 + cnt,[](node x,node y) {
		if (x.l != y.l) return x.l > y.l;
		if (x.type != y.type) return x.type > y.type;//
		return x.r < y.r;
	});
	for (int i = 1;i <= cnt;i ++)
		if (s[i].type == 2) update(1,1,n,s[i].r,s[i].r - s[i].l + 1);
		else {
			int L = s[i].l,R = s[i].r,id = s[i].id;
			int len = R - L + 1;
			TR ge = query(1,1,n,L,R);
			ans[id] += 2 * len * ge.val1 - ge.val2;
		}
//	for (int i = 1;i <= T;i ++) cout << ans[i] << "\n"[i != T]; 
	build(1,1,n);
	stable_sort(s + 1,s + 1 + cnt,[](node x,node y) {
		if (x.r != y.r) return x.r > y.r;
		if (x.type != y.type) return x.type < y.type;
		return x.l < y.l;
	});
	for (int i = 1;i <= cnt;i ++)
		if (s[i].type == 2) update(1,1,n,s[i].l,s[i].l);
		else {
			int L = s[i].l,R = s[i].r,id = s[i].id;
			int len = R - L + 1;
			int tk = R + 1;
			TR ge = query(1,1,n,L,R);
			ans[id] += 2 * len * (ge.val3 * tk - ge.val1) - (ge.val3 * tk * tk - 2 * tk * ge.val1 + ge.val2);
		}
	for (int i = 1;i <= T;i ++) printf("%lld\n",ans[i]);
	return 0;
}
posted @ 2025-08-01 11:08  high_skyy  阅读(14)  评论(0)    收藏  举报
浏览器标题切换
浏览器标题切换end