【ARC099D】 Eating Symbols Hard 题解 (双哈希 + 差分)

AtC 传送门:F - Eating Symbols Hard

双哈希 + 差分

Solution

1

发现要想确定 \([l,r]\) 对序列 \(A\) 的影响,必须得记录下 \(S\) 中的所有子串对 \(A\) 操作后的状态。进一步地,为了更好处理每一个 \([l,r]\),自然会想到使用差分的方法——记录下 \([1,i]\ (i \in [1,\left\vert S \right\vert])\)

但是上述想法暂时看起来仍有许多缺陷。注意到,我们的目的是确定 \(A\) 序列在不同情况下的状态,所以这就引导我们要去使用哈希甚至双哈希进行维护或记录。

为什么哈希能实现上述的差分呢?其实,稍加思考即可发现,即 \(h(i)\) 表示按照 \(S\) 的子序列 \([1,i]\) 对初始为 \(0\) 的序列 \(A\) 进行操作后 \(A\) 序列的哈希值。那么 \(h(r)-h(l-1)\) 实际上就是将前 \(i-1\) 个操作造成的影响抵消掉。因为真正对序列 \(A\) 造成影响的操作就是对它元素进行加减,而对元素的加减操作显然是可以通过相减抵消的。所以最终只留下第 \(l\)\(r\) 个操作对序列的影响。

2

回过头来,题目要求统计的合法区间 \([l,r]\) 貌似就是满足 \(h(r)-h(l-1)=h(n)\) 的区间。但是,这里忽略掉了一点,就是在 \(h(r)-h(l-1)\) 后,这个差值相比 \(h(n)\) 实质上会发生一个\(base\) 进制上的位移。简单解释一下,统计 \([1,n]\) 区间的哈希值时,下标 \(p\) 是从 \(0\) 开始的,第一个操作乘上的进制是 \(base^0\),然而,统计 \([l,r]\) 区间的哈希值时,下标 \(p\) 经过了前 \(l-1\) 次操作,不妨记 \(P_i\) 表示第 \(i\) 次操作时 \(p\) 的下标,那么统计区间 \([l,r]\) 时第一次乘上进制的就是 \(base^{P_{l}}\)

所以,为了消除后者比前者多出的从第 \(0\) 位到第 \(l - 1\) 位的位移,正确的柿子是 \(\dfrac{h(r)-h(l-1)}{base^{P_{l-1}}} = h(n)\)

3

知道了如何判断某一子区间是否合法之后,再考虑如何将它放进代码里实现。

看到时间限制和 \(\left\vert S \right\vert\) 的大小限制,发现统计时似乎只允许 \(\mathcal{\text{O}}(n)\) 的复杂度。这就引导我们想到一个大概的统计方法:枚举 \(l\) 遍历 \(1\)\(\left\vert S \right\vert\),然后统计合法的区间右端点 \(r\) 的数量。进一步地,这样一来 \(h(n)\)\(h(l-1)\) 已知,根据推出的柿子就发现 \(h(r)\) 已经被唯一确定了,即 \(h(r)=h(n)\times base^{P_{l-1}} + h(l-1)\),所以一开始计算 \(h(i)\) 的时候开一个 map 把它们按照哈希值统计数量即可。

至此,此题被完美解决, 复杂度可过。

4

还要再提一个代码实现的时候需要注意的点。题目中已述,\(p\) 有可能为负数。那么此时乘上的进制就成了 \(base^{-\left\vert p \right\vert}\),然后简单推一波柿子:\(base^{-\left\vert p \right\vert}=(base^{-1})^{\left\vert p\right\vert}\)。而 \(base^{-1}\) 是什么?\(base\) 在取模某个模数意义下的逆元。因为模数我们一定会设为一个素数,故直接用快速幂算出其逆元即可。

需要双哈希。

Code

#include<bits/stdc++.h>
using namespace std;

#define rep(i, a, b) for(register int i = a; i <= b; ++i)
typedef long long ll;
const int maxn = 250005;
const int mod1 = 1e9 + 7, mod2 = 998244353;
int n, p[maxn], bs = 121;
char c[maxn];
map<pair<int, int>, int> mp;
pair<int, int> hsh[maxn];
ll ans;

inline int pw(int x, int p, int mod){
	int res = 1;
	while(p){
		if(p & 1) 
			res = 1ll * res * x % mod;
		p /= 2, x = 1ll * x * x % mod;
	} return res;
}
inline pair<int, int> gh(int x, int y){
	int a, b; a = b = x;
	if(y < 0) 
		a = pw(a, mod1 - 2, mod1), b = pw(b, mod2 - 2, mod2), y *= -1; 
	return make_pair(pw(a, y, mod1), pw(b, y, mod2));
}

inline pair<int, int> add(pair<int, int> a, pair<int, int> b){
	return make_pair((a.first + b.first) % mod1, (a.second + b.second) % mod2);
}
inline pair<int, int> mns(pair<int, int> a, pair<int, int> b){
	return make_pair((a.first + mod1 - b.first) % mod1, (a.second + mod2 - b.second) % mod2);
}
inline pair<int, int> tim(pair<int, int> a, pair<int, int> b){
	return make_pair((1ll * a.first * b.first) % mod1, (1ll * a.second * b.second) % mod2);
}

int main(){
	scanf("%d", &n); 
	rep(i, 1, n){ 
		char c; cin >> c; 
		p[i] = p[i - 1];
		if(c == '+') hsh[i] = add(hsh[i - 1], gh(bs, p[i]));
		if(c == '-') hsh[i] = mns(hsh[i - 1], gh(bs, p[i]));
		if(c == '<') hsh[i] = hsh[i - 1], p[i] -= 1;
		if(c == '>') hsh[i] = hsh[i - 1], p[i] += 1;
		mp[hsh[i]] += 1;
	}
	rep(i, 1, n)
		ans += mp[add(hsh[i - 1], tim(hsh[n], gh(bs, p[i - 1])))], 
		mp[hsh[i]] -= 1;
	printf("%lld\n", ans);
	return 0;
}

感谢阅读。

posted @ 2022-07-20 08:25  pldzy  阅读(55)  评论(0)    收藏  举报