CF1223F

简要题解:考虑按照题意模拟,能消除就消除,否则入栈。每维护一个元素的加入后,用哈希维护当前栈的状态,最后就遍历哈希表(假设它是 $x$ 个位置的栈序列),拿所有 $C_{x}^{2}$ 求个和即可。

解题思路

考虑先借助可消除序列的构造方式对括号序列的形态进行观察:

  1. 假若 A 是一个可消除序列,那么 cAc 也是。

    其中 c 是一个字符,cAc 表示将两个元素 c 分别在前后与数组 A 拼接。

  2. 假若 A,B 都是的可消除序列,那么 AB 也是可消除序列。

容易发现,这是一个区间 DP 的形式,故而我们得到了一个 $O(n^3)$ 的做法。

但仔细考虑可消除序列的性质,不难得出一种复杂度更优的 $O(n^2)$ 的做法:我们枚举左端点,然后逐个贪心地加入下一个字符。其中,我们用栈维护序列,假设能弹栈就弹栈,不能弹就加入。如果某个时刻栈空了,那么得到的可消除序列数就加一。

这是一个比区间 DP 更有前途的做法,应为我们完全地脱离了区间 DP 枚举区间至少 $O(n^2)$ 的复杂度。

我们尝试在 $O(n^2)$ 做法的基础上抽离模型去简约地考虑。对于一个序列 1 1 2 2 3 3,不妨把 1 12 23 3 两两一组看成 $3$ 个整体去考虑。合法的括号序列数就是 $C_{3}^{2}$(组合意义就是选取两个整体分别作为左右端点)。

为了更清晰地发倔性质,我们不妨在上述括号序列最左边以及最右边各添加一个 $4$:4 1 1 2 2 3 3 4,此时的合法序列在原来的基础上还增加了 4 1 1 2 2 3 3 4

由上述的举例易得,从数组的起始位置开始贪心地维护栈,$\forall i,j\in [1,n]$,假若两个时刻的栈序列相同,那么它们之间的这段就是一个可消除序列(因为 $i,j$ 两者之间的位置经历了加栈和弹栈的过程,所以栈序列相同就证明中间的所有字符都能两两消除)。

$\forall i\in [1,n]$,用哈希维护一下栈序列,之后计数即可。

代码实现:

上文提及的 $O(n^2)$ 做法:

#include <bits/stdc++.h>
#define FL(i, a, b) for(int i = (a); i <= (b); ++i)
#define FR(i, a, b) for(int i = (a); i >= (b); --i)
using namespace std;
const int N = 2e6 + 10, N1 = 1510;
int n, tp, ans, f[N1][N1], s[N], st[N];
void solve(){
    scanf("%d", &n);
    FL(i, 1, n) scanf("%d", &s[i]);
    FL(i, 1, n){
        tp = 0;
        FL(j, i, n){
            if(!tp || st[tp] != s[j]) st[++tp] = s[j];
            else tp--;
            ans += (!tp);
        }
    }
    printf("%d\n", ans);
}
int main(){
    int T; scanf("%d", &T);
    while(T--) solve();
    return 0;
}

哈希的 $O(n)$ 做法:

#include <bits/stdc++.h>
#define FL(i, a, b) for(int i = (a); i <= (b); ++i)
#define FR(i, a, b) for(int i = (a); i >= (b); --i)
using namespace std;
typedef unsigned long long ull;
const int N = 3e5 + 10;
const unsigned long long base = 2333333;
unordered_map<ull, int> mp;
int n, tp, a[N], s[N], st[N];
ull h[N]; long long ans;
void solve(){
    scanf("%d", &n);
    FL(i, 1, n) scanf("%d", &s[i]);
    unordered_map<ull, int>().swap(mp);
    mp[0] = 1, ans = tp = 0;
    FL(i, 1, n){
        if(tp && st[tp] == s[i]) tp--;
        else{
            st[++tp] = s[i];
            h[tp] = h[tp - 1] * base + (ull)s[i];
        }
        mp[h[tp]]++;
    }
    for(auto x: mp) ans += 1ll * x.second * (x.second - 1) / 2;
    printf("%lld\n", ans);
}
int main(){
    int T; scanf("%d", &T);
    while(T--) solve();
    return 0;
}
posted @ 2023-10-24 08:55  徐子洋  阅读(13)  评论(0)    收藏  举报  来源