P3181 [HAOI2016]找相同字符

\(\color{#0066ff}{ 题目描述 }\)

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。

\(\color{#0066ff}{输入格式}\)

两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母

\(\color{#0066ff}{输出格式}\)

输出一个整数表示答案

\(\color{#0066ff}{输入样例}\)

aabb
bbaa

\(\color{#0066ff}{输出样例}\)

10

\(\color{#0066ff}{数据范围与提示}\)

none

\(\color{#0066ff}{ 题解 }\)

考虑把两个串拼起来,中间隔一个无关字符

我们每次找到一个合法的LCP,显然会产生LCP所有字串的贡献,但是这样会重复

我们定住一个端点,也就是让它产生LCP长度的贡献,这样在不同后缀中一端不同,相同后缀中另一端不同

怎么统计呢?

考虑单步容斥,用拼好的串的贡献-两个串内部贡献

#include<bits/stdc++.h>
#define LL long long
LL in() {
	char ch; LL x = 0, f = 1;
	while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
	for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
	return x * f;
}
const int inf = 0x7fffffff;
const int maxn = 4e5 + 5;
struct SA {
protected:
	int x[maxn], y[maxn], rk[maxn], sa[maxn], c[maxn], st[maxn];
	int top, n, m;
	LL l[maxn], r[maxn], h[maxn];
public:
	void operator () (char *s, int len) {
		n = len, m = 122;
		for(int i = 1; i <= n; i++) c[x[i] = s[i]]++;
		for(int i = 1; i <= m; i++) c[i] += c[i - 1];
		for(int i = n; i >= 1; i--) sa[c[x[i]]--] = i;
		for(int k = 1; k <= n; k <<= 1) {
			int num = 0;
			for(int i = n - k + 1; i <= n; i++) y[++num] = i;
			for(int i = 1; i <= n; i++) if(sa[i] > k) y[++num] = sa[i] - k;
			for(int i = 1; i <= m; i++) c[i] = 0;
			for(int i = 1; i <= n; i++) c[x[i]]++;
			for(int i = 1; i <= m; i++) c[i] += c[i - 1];
			for(int i = n; i >= 1; i--) sa[c[x[y[i]]]--] = y[i], y[i] = 0;
			std::swap(x, y);
			x[sa[1]] = 1, num = 1;
			for(int i = 2; i <= n; i++) 
				x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k])? num : ++num;
			if(n == num) break;
			m = num;
		}
		for(int i = 1; i <= n; i++) rk[i] = x[i];
		int H = 0;
		for(int i = 1; i <= n; i++) {
			if(rk[i] == 1) continue;
			if(H) H--;
			int j = sa[rk[i] - 1];
			while(i + H <= n && j + H <= n && s[j + H] == s[i + H]) H++;
			h[rk[i]] = H;
		}
	}
	LL getans() {
		LL ans = 0;
		h[0] = h[n + 1] = -inf;
		st[top = 1] = 0;
		for(int i = 1; i <= n; i++) {
			while(h[i] <= h[st[top]]) top--;
			l[i] = st[top];
			st[++top] = i;
		}
		st[top = 1] = n + 1;
		for(int i = n; i >= 1; i--) {
			while(h[i] < h[st[top]]) top--;
			r[i] = st[top];
			st[++top] = i;
		}
		for(LL i = 1; i <= n; i++) ans += (r[i] - i) * (i - l[i]) * h[i];
		return ans;
	}
}a, b, c;
char s[maxn], t[maxn];
int main() {
	scanf("%s", s + 1);
	scanf("%s", t + 1);
	int lens = strlen(s + 1);
	int lent = strlen(t + 1);
	a(s, lens);
	b(t, lent);
	s[++lens] = '#';
	for(int i = 1; i <= lent; i++) s[++lens] = t[i];
	c(s, lens);
	printf("%lld\n", c.getans() - a.getans() - b.getans());
	return 0;
}
posted @ 2019-01-13 11:50  olinr  阅读(220)  评论(0编辑  收藏  举报