BZOJ1367

# 题解

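将 $a_i$ 分成若干段，每段内的答案都取该段的中位数（对于单调递减的一段，整段取同一个值最优，而使绝对值之和最小的值就是中位数）。从左到右依次把每个元素作为新的一段加入；若最后一段的中位数小于前一段的中位数，就把这两段合并并重新维护中位数，直到不再出现这种逆序为止。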

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<map>
// Walks k along the linked adjacency list of u through h[] and ed[].nxt.
// h and ed are not declared in the visible part of this file.
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
// Loops i over 1..n inclusive.
#define REP(i,n) for (int i = 1; i <= (n); i++)
// Builds a pair<int,int>.
// NOTE(review): with explicit template arguments, C++11 and later only accept
// rvalue arguments here; plain make_pair(a,b) would be safer — confirm before use.
#define mp(a,b) make_pair<int,int>(a,b)
// Zero-fills s; s must be a real array so that sizeof(s) covers all of it.
#define cls(s) memset(s,0,sizeof(s))
// Shorthand for pair<int,int>.
#define cp pair<int,int>
// Shorthand for long long, used for the input values and the answer sum.
#define LL long long int
using namespace std;
// maxn sizes every per-element / per-node array below.
// maxm and INF are not referenced in the visible code.
const int maxn = 1000005,maxm = 100005,INF = 1000000000;
// Reads the next decimal integer from stdin and returns it.
// Skips every character up to the first digit; a '-' seen while skipping
// makes the result negative. Returns 0 if end of input is reached before
// any digit is found.
// NOTE(review): the header line of this function was missing from the file
// (the body sat at file scope); the name `read` is assumed — confirm against
// the callers.
inline int read(){
    int out = 0,flag = 1;
    // int, not char, so that EOF can be told apart from every real character.
    int c = getchar();
    while (c < '0' || c > '9'){
        if (c == EOF) return 0;   // avoid spinning forever on exhausted input
        if (c == '-') flag = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        out = out * 10 + (c - '0');
        c = getchar();
    }
    return out * flag;
}
// Heap node storage, indexed by node id; id 0 is used as "no node".
//   val : key stored in the node
//   ls, rs : left / right child ids
//   d   : rank compared by merge()/Merge() to decide whether to swap children
//   siz : number of nodes in the subtree rooted at the node
//   rt[p] : root of the heap built with merge() (smallest key on top) for the segment keyed by p
//   Rt[p] : root of the heap built with Merge() (largest key on top) for the segment keyed by p
int val[maxn],ls[maxn],rs[maxn],d[maxn],siz[maxn],rt[maxn],Rt[maxn];
// Melds the two heaps rooted at a and b, keeping the smaller key on top,
// and returns the root of the result. An id of 0 means an empty heap.
// Updates rs/ls, siz and d of every node on the merge path.
int merge(int a,int b){
    if (!a || !b) return a ? a : b;

    // The node with the smaller key becomes the root of the result.
    int top = a, other = b;
    if (val[b] < val[a]){
        top = b;
        other = a;
    }

    int &left = ls[top];
    int &right = rs[top];
    right = merge(right, other);
    siz[top] = siz[left] + siz[right] + 1;

    // Keep the child with the larger rank on the left.
    if (d[left] < d[right]) std::swap(left, right);

    if (right) d[top] = d[right] + 1;
    else d[top] = 0;
    return top;
}
// Melds the two heaps rooted at a and b, keeping the larger key on top,
// and returns the root of the result. An id of 0 means an empty heap.
// The rank update reads d[0] when the right child is empty, so it relies on
// the caller having set d[0] (work() sets it to -1).
int Merge(int a,int b){
    if (!a || !b) return a ? a : b;

    // The node with the larger key becomes the root of the result.
    int top = a, other = b;
    if (val[b] > val[a]){
        top = b;
        other = a;
    }

    int &left = ls[top];
    int &right = rs[top];
    right = Merge(right, other);
    siz[top] = siz[left] + siz[right] + 1;

    // Keep the child with the larger rank on the left.
    if (d[left] < d[right]) std::swap(left, right);

    d[top] = d[right] + 1;
    return top;
}
// n      : number of elements work() processes (indices 1..n)
// pos[k] : index of the element that opened the k-th segment; also the key
//          under which that segment's heaps are stored in rt[] / Rt[]
// len[k] : number of elements in the k-th segment
// K      : current number of segments
int n,pos[maxn],len[maxn],K;
// A[1..n]: the values work() operates on; presumably filled in by main — confirm.
LL A[maxn];
void work(){
int tmp; d[0] = -1;
for (int i = 1; i <= n; i++){
pos[++K] = i; len[K] = 1; rt[i] = i; siz[rt[i]] = 1; val[i] = A[i];
while (K > 1 && val[rt[pos[K]]] < val[rt[pos[K - 1]]]){
K--;
rt[pos[K]] = merge(rt[pos[K]],rt[pos[K + 1]]);
Rt[pos[K]] = Merge(Rt[pos[K]],Rt[pos[K + 1]]);
len[K] += len[K + 1];
while (siz[rt[pos[K]]] > siz[Rt[pos[K]]]){
tmp = rt[pos[K]];
rt[pos[K]] = merge(ls[rt[pos[K]]],rs[rt[pos[K]]]);
ls[tmp] = rs[tmp] = 0; siz[tmp] = 1;
Rt[pos[K]] = Merge(Rt[pos[K]],tmp);
}
while (siz[rt[pos[K]]] < siz[Rt[pos[K]]]){
tmp = Rt[pos[K]];
Rt[pos[K]] = Merge(ls[Rt[pos[K]]],rs[Rt[pos[K]]]);
ls[tmp] = rs[tmp] = 0; siz[tmp] = 1;
rt[pos[K]] = merge(rt[pos[K]],tmp);
}
}
//printf("[%d,%d]  mid = %d\n",i - len[K] + 1,i,val[rt[pos[K]]]);
}
LL ans = 0,v;
for (int i = 1,l = 1; i <= K; i++){
v = siz[rt[pos[i]]] > siz[Rt[pos[i]]] ? val[rt[pos[i]]] : val[Rt[pos[i]]];
//printf("[%d,%d]  v = %lld\n",l,l + len[i] - 1,v);
for (int j = 0; j < len[i]; j++)
ans += abs(v - A[l + j]);
l += len[i];
}
printf("%lld\n",ans);
}
int main(){