BZOJ4826 [Hnoi2017]影魔 【线段树 + 单调栈】

题目链接

BZOJ4826

题解

蒟蒻智力水平捉急orz

我们会发现相邻的\(i\)\(j\)贡献一定是\(p1\),可以很快算出来【然而我一开始忘了考虑调了半天】

我们现在只考虑不相邻的
我们只需要找出所有产生贡献的\(i,j\)即可
我们发现每一个产生贡献的\(i,j\)都能对应到一个三元组\((i,k,j)\),分别对应区间的最大值,次大值,第三大值
我们枚举中间位置\(i\),找到\(i\)左边第一个比\(i\)大的位置\(L[i]\),右边第一个比\(i\)大的位置\(R[i]\)
那么\(L[i]\)\(R[i]\)的贡献就是\(p1\)
区间\((L[i],i)\)\(R[i]\)的贡献是\(p2\)
区间\((i,R[i])\)\(L[i]\)的贡献是\(p2\)
一对点对询问产生贡献,当且仅当其都在区间中
所以我们可以用一个端点去储存贡献,而另一个端点作为更新的位置
然后再离线询问,用到\(r\)端点时区间\([l,r]\)的贡献减去到\(l - 1\)时区间\([l,r]\)的贡献,就是答案

#include<algorithm>
#include<iostream>
#include<cstdio>
#include<vector>
#include<cmath>
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define LL long long int
#define ls (u << 1)
#define rs (u << 1 | 1)
using namespace std;
const int maxn = 200005,maxm = 100005,INF = 1000000000;
inline int read(){
	int out = 0,flag = 1; char c = getchar();
	while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
	while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
	return out * flag;
}
struct node{int l,r;};
struct Que{int l,r,pos,t,id;}q[maxn << 1];
vector<node> Ln[maxn],Rn[maxn];
LL ans[maxn];
LL A[maxn],n,m,p1,p2,qi;
LL st[maxn],top,Li[maxn],Ri[maxn];
LL sum[maxn << 2],tag[maxn << 2];
inline bool operator <(const Que& a,const Que& b){
	return a.pos < b.pos;
}
void upd(int u){sum[u] = sum[ls] + sum[rs];}
void pd(int u,int l,int r){
	if (tag[u]){
		int mid = l + r >> 1;
		sum[ls] += tag[u] * (mid - l + 1);
		sum[rs] += tag[u] * (r - mid);
		tag[ls] += tag[u]; tag[rs] += tag[u];
		tag[u] = 0;
	}
}
void add(int u,int l,int r,int L,int R,LL v){
	if (l >= L && r <= R){sum[u] += v * (r - l + 1); tag[u] += v; return;}
	pd(u,l,r);
	int mid = l + r >> 1;
	if (mid >= L) add(ls,l,mid,L,R,v);
	if (mid < R) add(rs,mid + 1,r,L,R,v);
	upd(u);
}
LL query(int u,int l,int r,int L,int R){
	if (l >= L && r <= R) return sum[u];
	pd(u,l,r);
	int mid = l + r >> 1;
	if (mid >= R) return query(ls,l,mid,L,R);
	if (mid < L) return query(rs,mid + 1,r,L,R);
	return query(ls,l,mid,L,R) + query(rs,mid + 1,r,L,R);
}
void init(){
	for (int i = 1; i <= n; i++){
		while (top && A[i] > A[st[top]]) top--;
		Li[i] = st[top];
		st[++top] = i;
	}
	st[0] = n + 1; top = 0;
	for (int i = n; i; i--){
		while (top && A[i] > A[st[top]]) top--;
		Ri[i] = st[top];
		st[++top] = i;
	}
	for (int i = 1; i <= n; i++){
		//printf("pos: %d   L:[%d,%d]  R:[%d,%d]\n",i,Li[i] + 1,i - 1,i + 1,Ri[i] - 1);
		if (Li[i])
			Ln[Li[i]].push_back((node){i + 1,Ri[i] - 1});
		if (Ri[i] <= n)
			Rn[Ri[i]].push_back((node){Li[i] + 1,i - 1});
	}
}
void solve(){
	sort(q + 1,q + 1 + qi); int l,r,pos = 1;
	for (int i = 0; i <= n; i++){
		for (unsigned int j = 0; j < Ln[i].size(); j++){
			l = Ln[i][j].l; r = Ln[i][j].r;
			if (l <= r) add(1,1,n,l,r,p2);
		}
		for (unsigned int j = 0; j < Rn[i].size(); j++){
			l = Rn[i][j].l; r = Rn[i][j].r;
			if (l <= r) add(1,1,n,l,r,p2);
			if (l - 1 > 0) add(1,1,n,l - 1,l - 1,p1);
		}
		while (pos <= qi && q[pos].pos == i){
			ans[q[pos].id] += query(1,1,n,q[pos].l,q[pos].r) * q[pos].t;
			pos++;
		}
	}
	for (int i = 1; i <= m; i++)
		printf("%lld\n",ans[i]);
}
int main(){
	n = read(); m = read(); p1 = read(); p2 = read(); int l,r;
	for (int i = 1; i <= n; i++) A[i] = read();
	for (int i = 1; i <= m; i++){
		l = read(); r = read();
		q[++qi] = (Que){l,r,l - 1,-1,i};
		q[++qi] = (Que){l,r,r,1,i};
		ans[i] += 1ll * (r - l) * p1;
	}
	init();
	solve();
	return 0;
}

posted @ 2018-05-15 20:54  Mychael  阅读(243)  评论(0编辑  收藏  举报