[BZOJ3745][Coci2015]Norma

[BZOJ3745][Coci2015]Norma

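试题描述

给定一个长度为 $N$ 的正整数序列 $a_1, a_2, \dots, a_N$,求所有连续子区间的「长度 × 最小值 × 最大值」之和,即
$$\sum_{1 \le i \le j \le N} (j - i + 1) \cdot \min_{i \le k \le j} a_k \cdot \max_{i \le k \le j} a_k .$$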

输入

第 1 行:一个整数 $N$;
第 $2 \sim N+1$ 行:每行一个整数,依次给出序列 $a_1, \dots, a_N$。

输出

输出一个整数,即答案对 $10^9$ 取模后的结果。

输入示例

4
2
4
1
4

输出示例

109

数据规模及约定

$N \le 500000$
$1 \le a_i \le 10^8$

题解

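分治:每一层只统计跨过中点 mid 的区间 [i, j](i ≤ mid < j)。从 mid 向左枚举左端点 i,维护 a[i..mid] 的最大值和最小值;右端点 j 用两个单调不减的指针分成三段——最大值和最小值都来自左半边、都来自右半边、一个来自左半边一个来自右半边。每一段的贡献都可以用右半边预处理好的前缀和 O(1) 求出,因此每层 O(r - l),总复杂度 O(N log N)。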

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;

// Reads one decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. Reading stops at the first non-digit after the
// number (that character is consumed).
// Returns 0 if the input ends before any digit is found, instead of looping
// forever waiting for a digit that will never arrive.
int read() {
	int x = 0, f = 1;
	// Keep getchar()'s result as int: a char cannot distinguish EOF from a real
	// byte, and passing a negative char to isdigit() is undefined behavior.
	int c = getchar();
	while(c != EOF && !isdigit(c)) {
		if(c == '-') f = -1;
		c = getchar();
	}
	while(c != EOF && isdigit(c)) {
		x = x * 10 + (c - '0');
		c = getchar();
	}
	return x * f;
}

// Problem-wide constants, as typed constants instead of untyped macros.
const int maxn = 500010;      // array capacity: N <= 500000, 1-based indexing plus slack
const int MOD = 1000000000;   // answers are reported modulo 10^9
const int oo = 2147483647;    // INT_MAX; starting value for running minimums
typedef long long LL;         // 64-bit type for products taken modulo MOD

// n and A[1..n] hold the input sequence.
// The other arrays are scratch space reused by solve(). For a right half
// [mid+1, r] they hold prefix sums (mod MOD) indexed by right endpoint j,
// with index mid holding the empty prefix 0. Here mx/mn are the max/min of
// A[mid+1..j]:
//   sl: (j - mid)   smx: mx   smn: mn   slmx: (j - mid)*mx   slmn: (j - mid)*mn
//   sm: mx*mn       slm: (j - mid)*mx*mn
int n, A[maxn], sl[maxn], smx[maxn], smn[maxn], slmx[maxn], slmn[maxn], sm[maxn], slm[maxn];

// Returns, modulo MOD, the sum over every subarray [x, y] with l <= x <= y <= r
// of (y - x + 1) * min(A[x..y]) * max(A[x..y]).
// Divide and conquer: both halves are solved recursively, then this call adds
// every subarray that crosses the midpoint (x <= mid < y).
// Requires l <= r: with l > r the recursion never bottoms out
// (e.g. solve(1, 0) calls solve(1, 0) again).
// Reads the global A. After both recursive calls return, it overwrites the
// global scratch arrays sl/smx/smn/slmx/slmn/sm/slm on indices [mid, r].
int solve(int l, int r) {
	// Single element: length 1, min == max == A[l].
	if(l == r) return (LL)A[l] * A[l] % MOD;
	int mid = l + r >> 1, ans = solve(l, mid) + solve(mid + 1, r);
	if(ans >= MOD) ans -= MOD;
	// Index mid is the empty prefix for the right-half prefix sums below.
	sl[mid] = smx[mid] = smn[mid] = slmx[mid] = slmn[mid] = sm[mid] = slm[mid] = 0;
	int mx = 0, mn = oo;
	// For each right endpoint j in [mid+1, r], with mx/mn = max/min of A[mid+1..j],
	// accumulate prefix sums (mod MOD) of:
	//   sl: (j - mid)          smx: mx                smn: mn
	//   slmx: (j - mid) * mx   slmn: (j - mid) * mn   sm: mx * mn
	//   slm: (j - mid) * mx * mn
	for(int i = mid + 1; i <= r; i++) {
		sl[i] = sl[i-1] + i - mid; if(sl[i] >= MOD) sl[i] -= MOD;
		mx = max(mx, A[i]); mn = min(mn, A[i]);
		smx[i] = smx[i-1] + mx; if(smx[i] >= MOD) smx[i] -= MOD;
		smn[i] = smn[i-1] + mn; if(smn[i] >= MOD) smn[i] -= MOD;
		slmx[i] = slmx[i-1] + (LL)(i - mid) * mx % MOD; if(slmx[i] >= MOD) slmx[i] -= MOD;
		slmn[i] = slmn[i-1] + (LL)(i - mid) * mn % MOD; if(slmn[i] >= MOD) slmn[i] -= MOD;
		sm[i] = sm[i-1] + (LL)mx * mn % MOD; if(sm[i] >= MOD) sm[i] -= MOD;
		slm[i] = slm[i-1] + (LL)(i - mid) * mx % MOD * mn % MOD; if(slm[i] >= MOD) slm[i] -= MOD;
	}
	mx = 0; mn = oo;
	// The capitalised variables describe the current left endpoint i. With
	// len = mid - i + 1 and mx/mn = max/min of A[i..mid] they are (mod MOD):
	//   Sl = len, Smx = mx, Smn = mn, Slmx = len * mx, Slmn = len * mn,
	//   Sm = mx * mn, Slm = len * mx * mn.
	// mntr: first right index j where A[mid+1..j] contains a value <= mn. From
	//   there on the subarray's minimum is taken from the right half.
	// mxtr: the same for the maximum (a value >= mx).
	// Both pointers only move forward as i decreases.
	int Sl = 0, Smx = 0, Smn = 0, Slmx = 0, Slmn = 0, Sm = 0, Slm = 0, mntr = mid + 1, mxtr = mid + 1;
	for(int i = mid; i >= l; i--) {
		mx = max(mx, A[i]); mn = min(mn, A[i]);
		Sl = mid - i + 1;
		Smx = mx;
		Smn = mn;
		Slmx = (LL)(mid - i + 1) * mx % MOD;
		Slmn = (LL)(mid - i + 1) * mn % MOD;
		Sm = (LL)mx * mn % MOD;
		Slm = (LL)(mid - i + 1) * mx % MOD * mn % MOD;
		// Comparing with A[i] instead of mn/mx suffices. A pointer can only need
		// to advance when A[i] is the new left minimum (maximum), i.e. A[i] == mn (mx).
		while(mntr <= r && A[i] < A[mntr]) mntr++;
		while(mxtr <= r && A[i] > A[mxtr]) mxtr++;
		// j in [mid+1, min(mntr, mxtr) - 1]: both min and max come from the left half.
		//   sum of mx*mn*(len + (j - mid)) = Slm * count + Sm * sum(j - mid)
		// The "+ MOD" keeps a difference of reduced prefix sums non-negative.
		int tmp = min(mntr, mxtr);
		ans += ((LL)Slm * (tmp - mid - 1) + (LL)Sm * (sl[tmp-1] - sl[mid] + MOD)) % MOD;
		if(ans >= MOD) ans -= MOD;
		// j in [max(mntr, mxtr), r]: both min and max come from the right half.
		//   sum of rmx*rmn*(len + (j - mid)) = Sl * sum(sm) + sum(slm)
		tmp = max(mntr, mxtr);
		ans += ((LL)Sl * (sm[r] - sm[tmp-1] + MOD) + slm[r] - slm[tmp-1] + MOD) % MOD;
		if(ans >= MOD) ans -= MOD;
		if(mntr < mxtr) {
			// j in [mntr, mxtr - 1]: min from the right half, max from the left.
			//   sum of mx*rmn*(len + (j - mid)) = Slmx * sum(smn) + Smx * sum(slmn)
			ans += ((LL)Slmx * (smn[mxtr-1] - smn[mntr-1] + MOD) + (LL)Smx * (slmn[mxtr-1] - slmn[mntr-1] + MOD)) % MOD;
			if(ans >= MOD) ans -= MOD;
		}
		else {
			// j in [mxtr, mntr - 1]: max from the right half, min from the left.
			// The range is empty when mntr == mxtr.
			//   sum of mn*rmx*(len + (j - mid)) = Slmn * sum(smx) + Smn * sum(slmx)
			ans += ((LL)Slmn * (smx[mntr-1] - smx[mxtr-1] + MOD) + (LL)Smn * (slmx[mntr-1] - slmx[mxtr-1] + MOD)) % MOD;
			if(ans >= MOD) ans -= MOD;
		}
	}
	return ans;
}

int main() {
	n = read();
	for(int i = 1; i <= n; i++) A[i] = read();
	
	printf("%d\n", solve(1, n));
	
	return 0;
}

 

posted @ 2017-09-20 19:11  xjr01  阅读(203)  评论(0编辑  收藏  举报