AtCoder Beginner Contest 248 H
https://atcoder.jp/contests/abc248/tasks/abc248_h
官方题解使用的是线段树,不过分治可能更简单一些。
我们调用\(divide(l,r)\),表示区间\([l,r]\)的合法子序列个数。
根据分治的套路:\(divide(l,r)=divide(l,mid)+divide(mid+1,r)+f\),其中\(f\)表示跨越左右两个侧的合法子序列个数。(这里和下面“左侧”指的是区间\([l,mid]\),“右侧”指的是区间\([mid+1,r]\))
问题是如何求\(f\)。
分为下面四种情况:
- 最大、最小值都在左侧
- 最大、最小值都在右侧
- 最大值在左侧,最小值在右侧
- 最大值在右侧,最小值在左侧
1,2两种情况类似,我们就讨论情况1,另一种程序是差不多的。
若最大、最小值都在左侧,那么我们枚举\(L\),表示合法区间的左端点,枚举\(K\),表示当前\(k\)的大小,那么:\(max-min=R-L+K\),可以变形成:\(R=L-K+max-min\)。而最小、最大值都在左侧,故我们可以知道\(max\)和\(min\)。接下来我们要判断\(R\)是否合法。首先\(mid<R<=r\),其次\(\min_{i=mid+1}^{R} p_{i}>\min_{i=L}^{mid} p_{i}\),\(\max_{i=mid+1}^{R} p_{i}<\max_{i=L}^{mid} p_{i}\),否则最大或最小值就不在左侧了。
预处理\(min,max\),时间复杂度为\(O(nk)\)。
3,4两种情况,我们只讨论4,另一种程序是差不多的。
我们继续枚举\(L\),意义同上,此时必须满足:
此时,我们发现,若\(L\)从大往小枚举(即从\(mid\)枚举到\(l\)),那么\(\min_{i=L}^{mid} p_{i}\)越来越小,\(\max_{i=L}^{mid} p_{i}\)越来越大。那么我们可以使用单调队列维护合法的\(R\)的位置。
那么如何进行统计呢?我们也是枚举\(K\),此时\(min,L,K\)都是已知的,而\(R,max\)是未知的,那么:
故我们维护一个桶\(cnt\),单调队列加入一个元素时,\(cnt[max-R]++\),删除时\(cnt[max-R]--\),查询时只要查\(cnt[min-L+K]\)的大小即可。
注意到\(max-R\)有可能小于\(0\),故我们要将所有下标加上\(n\)。
时间复杂度也是\(O(nk)\)。
故,一次分治的复杂度为\(O(nk)\),一共\(\log n\)层,总的时间复杂度为\(O(nk\log n)\)。
代码如下:
#include<bits/stdc++.h>
#define debug(...) std::cerr<<#__VA_ARGS__<<" : "<<__VA_ARGS__<<std::endl
using ll=long long;
const int maxn=150005;
int n,k;
int a[maxn],mx[maxn],mn[maxn],cnt[maxn+maxn];
ll divide(int l,int r) {
if(l==r) return 1ll;
int m=l+r>>1; ll ret=divide(l,m)+divide(m+1,r);
mx[m]=mn[m]=a[m],mx[m+1]=mn[m+1]=a[m+1];
for(int i=m-1;i>=l;i--) {
mx[i]=std::max(mx[i+1],a[i]);
mn[i]=std::min(mn[i+1],a[i]);
}
for(int i=m+2;i<=r;i++) {
mx[i]=std::max(mx[i-1],a[i]);
mn[i]=std::min(mn[i-1],a[i]);
}
for(int L=m;L>=l;L--)
for(int K=0;K<=k;K++) {
int R=mx[L]-mn[L]+L-K;
if(m<R&&R<=r&&mx[R]<mx[L]&&mn[R]>mn[L]) ret++;//注意判断R的条件不能漏
}
for(int R=m+1;R<=r;R++)
for(int K=0;K<=k;K++) {
int L=R+K-mx[R]+mn[R];
if(l<=L&&L<=m&&mx[L]<mx[R]&&mn[L]>mn[R]) ret++;
}
int R1=m+1,R2=m+1;
for(int L=m;L>=l;L--) {
while(R2<=r&&mn[R2]>mn[L]) cnt[mx[R2]-R2+n]++,R2++;//注意:单调队列必须先加后删
while(R1<R2&&mx[R1]<mx[L]) cnt[mx[R1]-R1+n]--,R1++;
for(int K=0;K<=k;K++) ret+=cnt[mn[L]-L+K+n];
}
while(R1<R2) cnt[mx[R1]-R1+n]--,R1++;//清空cnt
int L1=m,L2=m;
for(int R=m+1;R<=r;R++) {
while(L2>=l&&mn[L2]>mn[R]) cnt[mx[L2]+L2]++,L2--;
while(L1>L2&&mx[L1]<mx[R]) cnt[mx[L1]+L1]--,L1--;
for(int K=0;K<=k;K++) ret+=cnt[R+K+mn[R]];
}
while(L1>L2) cnt[mx[L1]+L1]--,L1--;
return ret;
}
int main() {
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
printf("%lld\n",divide(1,n));
return 0;
}
浙公网安备 33010602011771号