【BZOJ3238】[AHOI2013]差异

【BZOJ3238】[AHOI2013]差异

题面

给定字符串\(S\),令\(T_i\)表示以它从第\(i\)个字符开始的后缀。求

\[\sum_{1\leq i<j\leq n}len(T_i)+len(T_j)-2*lcp(T_i,T_j) \]

其中\(len(a)\)表示串\(a\)的长度,\(lcp(a,b)\)表示串\(a,b\)的最长公共前缀

题解

把这个式子看作两边分开求:

Part1:

\[\sum_{1\leq i<j\leq n}len(T_i)+len(T_j)\\ \Leftrightarrow\sum_{i=1}^{n-1}\sum_{j=i+1}^ni+j\\ =\sum_{i=1}^{n-1}i*(n-i)+\frac{(i+1+n)*(n-i)}2 \]

其实现在你就可以\(O(n)\)地求了,但是因为我出(kan)于(le)美(ti)观(jie)

发现它其实可以化成这样:

\[\frac {(n-1)*n*(n+1)}2 \]

Part2:

一看到后缀当然是\(sa\)

由后缀数组的性质,排名为分别为\(i,j\)的后缀,\(lcp_{i,j}=\min\limits_{k=i+1}^jheight_k\)

我们将所有高度数组排成一排,

假设中间的第\(i\)个数是\(l-r\)中最小的

则它的贡献就是\((i-l+1)*(r-i+1)\)

我们处理出来对\(i\)所有的\(l,r\)是不是就做出来了呢

这不就是一个单调栈的经典应用吗?

而这个题目中因为一些细节问题我的\(l\)表示小于\(i\)的第一个,\(r\)表示小于等于\(i\)的第一个

详见代码:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring> 
#include <cmath> 
#include <algorithm>
using namespace std; 
inline int gi() {
	register int data = 0, w = 1;
	register char ch = 0;
	while (!isdigit(ch) && ch != '-') ch = getchar(); 
	if (ch == '-') w = -1, ch = getchar();
	while (isdigit(ch)) data = 10 * data + ch - '0', ch = getchar();
	return w * data; 
}
const int MAX_N = 5e5 + 5; 
int N; char a[MAX_N]; 
int sa[MAX_N], lcp[MAX_N], rnk[MAX_N]; 
void GetSA() { 
#define cmp(i, j, k) (y[i] == y[j] && y[i + k] == y[j + k]) 
	static int x[MAX_N], y[MAX_N], bln[MAX_N]; 
	int M = 122; 
	for (int i = 1; i <= N; i++) bln[x[i] = a[i]]++; 
	for (int i = 1; i <= M; i++) bln[i] += bln[i - 1]; 
	for (int i = N; i >= 1; i--) sa[bln[x[i]]--] = i; 
	for (int k = 1; k <= N; k <<= 1) { 
		int p = 0; 
		for (int i = 0; i <= M; i++) y[i] = 0; 
		for (int i = N - k + 1; i <= N; i++) y[++p] = i; 
		for (int i = 1; i <= N; i++) if (sa[i] > k) y[++p] = sa[i] - k; 
		for (int i = 0; i <= M; i++) bln[i] = 0; 
		for (int i = 1; i <= N; i++) bln[x[y[i]]]++; 
		for (int i = 1; i <= M; i++) bln[i] += bln[i - 1]; 
		for (int i = N; i >= 1; i--) sa[bln[x[y[i]]]--] = y[i]; 
		swap(x, y); x[sa[1]] = p = 1; 
		for (int i = 2; i <= N; i++) x[sa[i]] = cmp(sa[i], sa[i - 1], k) ? p : ++p; 
		if (p >= N) break; 
		M = p; 
	} 
}
void GetLcp() { 
	for (int i = 1; i <= N; i++) rnk[sa[i]] = i; 
	for (int i = 1, j = 0; i <= N; i++) { 
		if (j) --j; 
		while (a[i + j] == a[sa[rnk[i] - 1] + j]) ++j; 
		lcp[rnk[i]] = j; 
	} 
}
typedef long long ll; 
int lp[MAX_N], rp[MAX_N], stk[MAX_N], top; 
int main () {
	scanf("%s", a + 1); N = strlen(a + 1); 
	GetSA(); GetLcp(); 
	ll ans = 0; 
	//for (int i = 1; i < N; i++) ans += 1ll * i * (N - i) + 1ll * (i + 1 + N) * (N - i) / 2ll; 
	ans = 1ll * N * (N + 1) * (N - 1) / 2ll; 
	stk[0] = 1; 
	for (int i = 2; i <= N; i++) {
		while (top > 0 && lcp[stk[top]] >= lcp[i]) --top; 
		lp[i] = i - stk[top], stk[++top] = i; 
	} 
	top = 0, stk[0] = N + 1; 
	for (int i = N; i >= 2; i--) { 
		while (top > 0 && lcp[stk[top]] > lcp[i]) --top; 
		rp[i] = stk[top] - i, stk[++top] = i; 
	}
	for (int i = 2; i <= N; i++) ans -= 2ll * lcp[i] * lp[i] * rp[i]; 
	printf("%lld\n", ans); 
	return 0; 
} 
posted @ 2019-01-22 20:02  heyujun  阅读(347)  评论(1编辑  收藏  举报