ABC247E Max Min 题解

ABC247E Max Min Solution

更好的阅读体验戳此进入

题面

给定数列 $ A_n $,给定 $ X, Y $,我们定义数对 $ (L, R) $ 满足 $ 1 \le L \le R \le n $,且数列 $ A_L, A_{L + 1}, \cdots, A_R $ 满足最大值为 $ X $,最小值为 $ Y $,求有多少种满足条件的数对。

Solution

关于本题有一些较为显然的 $ O(n \log ^2 n) $ 或 $ O(n \log n) $ 或 $ O(n) $ 的做法,前两个大概对应着树套树和线段树,然后后者大概是维护前一个等于 $ X $ 等于 $ Y $,在 $ (X, Y) $ 之间的啊等等一堆东西然后跑一遍,这里就不多赘述了,网上相关做法很多。这里主要提供一个机房巨佬 @Zpair 想出来的人类智慧容斥的做法与推导。

首先我们考虑这段区间里一定不能包含在 $ (-\infty, Y) \cup (X, +\infty) $ 内的数,所以我们考虑让所有在这个区间内的数作为分割点,这样最终会让我们的序列被分成很多个子序列,这里我们不难发现,合法的区间一定不会有跨越两个子区间的,一定都是在子区间内的子子区间。

所以这时候我们只需要对于每个划分出来的子区间考虑其中的合法区间即可,不难想到我们要找的区间合法当且仅当里面包含了至少一个 $ X $ 和至少一个 $ Y $,为了方便记录这里我们令等于 $ X $ 的点为 $ 1 $,等于 $ Y $ 的点为 $ -1 $,在 $ (X, Y) $ 之间的点为 $ 0 $,也就是我们要找的就是每个子区间里包含 $ -1 $ 和 $ 1 $ 的区间由多少个。这里我们令包含 $ 1 $ 的区间的集合为 $ A $,包含 $ -1 $ 的为 $ B $,令子区间内的所有区间的集合为 $ \Omega $,不难想到我们比较好求的是,不包含 $ 1 $ 和 $ -1 $ 的区间,不包含 $ 1 $ 的区间,不包含 $ -1 $ 的区间,以及全集,分别对应着 $ \overline{A} \cap \overline{B} \(,\) \overline{A} \(,\) \overline{B} \(,\) \Omega $。我们的目标是求出 $ A \cap B $,所以我们希望通过前面的四个求出答案,不难想到如下推导:

\[A \cap B = \Omega - \overline{A} - \overline{B} + \overline{A} \cap \overline{B} \]

此时考虑如何维护这四个值,显然我们可以考虑把子区间再次进行分割,遍历三次,分别由 $ 1 $ 分割,由 $ -1 $ 分割,由 $ -1 $ 或 $ 1 $ 分割,也就分别对应着式子中的三个区间的大小,然后对于每个子区间维护一遍即可,最后就是一个大常数的线性求解。

然后还有个点就是如果 $ X = Y $ 的话那么其实际上是既是 $ 1 $ 又是 $ -1 $,需要特判一下。

Code

#define _USE_MATH_DEFINES
#include <bits/extc++.h>

#define PI M_PI
#define E M_E
#define npt nullptr
#define SON i->to
#define OPNEW void* operator new(size_t)
#define ROPNEW(arr) void* Edge::operator new(size_t){static Edge* P = arr; return P++;}

using namespace std;
using namespace __gnu_pbds;

mt19937 rnd(random_device{}());
int rndd(int l, int r){return rnd() % (r - l + 1) + l;}
bool rnddd(int x){return rndd(1, 100) <= x;}

typedef unsigned int uint;
typedef unsigned long long unll;
typedef long long ll;
typedef long double ld;

template< typename T = int >
inline T read(void);

int N, X, Y;
int a[210000];
int cnt0(0), cnt1(0), cnt_1(0), l(1), r(-1);
ll ans(0);

ll GetC(int n, int m){
    if(n == 1)return 1;
    if(!n || m > n)return 0;
    ll ret(1);
    for(int i = n; i >= n - m + 1; --i)ret *= i;
    for(int i = m; i >= 1; --i)ret /= i;
    return ret + n;
}

int main(){
    N = read(), X = read(), Y = read();
    for(int i = 1; i <= N + 1; ++i){
        a[i] = i <= N ? read() : INT_MAX;
        a[i] = a[i] == X && a[i] == Y ? -114514 : (a[i] == X ? 1 : (a[i] == Y ? -1 : (Y <= a[i] && a[i] <= X ? 0 : 114514)));
        if(a[i] != 114514){r = i; continue;}
        if(!~r){l = i + 1; continue;}
        ll cur(0);
        cur += GetC(r - l + 1, 2);
        int ll(l), rr(-1);
        for(int j = l; j <= r; ++j){
            if(a[j] != -1 && a[j] != -114514){rr = j; continue;}
            if(!~rr){ll = j + 1; continue;}
            cur -= GetC(rr - ll + 1, 2);
            ll = j + 1;
        }if(ll <= rr)cur -= GetC(rr - ll + 1, 2);
        ll = l, rr = -1;
        for(int j = l; j <= r; ++j){
            if(a[j] != 1 && a[j] != -114514){rr = j; continue;}
            if(!~rr){ll = j + 1; continue;}
            cur -= GetC(rr - ll + 1, 2);
            ll = j + 1;
        }if(ll <= rr)cur -= GetC(rr - ll + 1, 2);
        ll = l, rr = -1;
        for(int j = l; j <= r; ++j){
            if(!a[j]){rr = j; continue;}
            if(!~rr){ll = j + 1; continue;}
            cur += GetC(rr - ll + 1, 2);
            ll = j + 1;
        }if(ll <= rr)cur += GetC(rr - ll + 1, 2);
        ans += cur;
        l = i + 1;
    }
    printf("%lld\n", ans);

    fprintf(stderr, "Time: %.6lf\n", (double)clock() / CLOCKS_PER_SEC);
    return 0;
}

template < typename T >
inline T read(void){
    T ret(0);
    short flag(1);
    char c = getchar();
    while(c != '-' && !isdigit(c))c = getchar();
    if(c == '-')flag = -1, c = getchar();
    while(isdigit(c)){
        ret *= 10;
        ret += int(c - '0');
        c = getchar();
    }
    ret *= flag;
    return ret;
}

然后因为我是完全按照这个思路实现的所以上面的可能非常丑,所以这里也贴一个 @Zpair 的优美实现。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int l1,l2,l3,l4,n,x,y;ll ans;
ll get(int x){return (ll)x*(x+1)/2;}
int main(){
	cin>>n>>x>>y;int v;
	for(int i=1;i<=n;++i){
		scanf("%d",&v);
		if(v>x||v<y){ans+=get(l1)-get(l2)-get(l3)+get(l4),l1=l2=l3=l4=0;continue;}
		if(v!=x&&v!=y)l1++,l2++,l3++,l4++;
		if(v==x&&v==y)l1++,ans+=-get(l2)-get(l3)+get(l4),l2=l3=l4=0;
		if(v==x&&v!=y)l1++,l3++,ans+=-get(l2)+get(l4),l2=l4=0;
		if(v!=x&&v==y)l1++,l2++,ans+=-get(l3)+get(l4),l3=l4=0;
	}
	ans+=get(l1)-get(l2)-get(l3)+get(l4);
	cout<<ans;
}

UPD

update-2022_10_24 初稿

posted @ 2022-12-02 20:01  Tsawke  阅读(33)  评论(0)    收藏  举报