bzoj 4566[Haoi2016]找相同字符 - 后缀数组 + 单调栈

4566: [Haoi2016]找相同字符

Time Limit: 20 Sec  Memory Limit: 256 MB

Description

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。

 

Input

两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母

 

Output

输出一个整数表示答案

 

Sample Input

aabb
bbaa

Sample Output

10
 
 首先最简单的方法是把两个字符串接在一起,中间插入间隔符,求出height数组;
答案是 所有属于A串的后缀 和 属于B串的后缀 的 LCP 求和 
如果用 height 数组 + st表预处理 是 n^2的,这是不可接受的
所以我们可以尝试用单调栈来解决。
分别处理 A串 在 B串前 和 A串在B串后的情况。
处理到 排名为 i的串的时候,如果栈中的height > height[i], 就弹出
这样在栈顶和 当前串 之间 的串的height  必然 是 比 当前串 大的
所以这些串的贡献就等于 当前串的 height
而对于比栈顶串更靠前的串,因为栈顶串的height 是比当前串的height 小的,所以他们的贡献是在栈顶串之前就决定的,可以直接加上
用sum[i]求出区间内的有贡献串的数量就可以了。
 
 
 
  1 #include <iostream>
  2 #include <cstdio>
  3 #include <cstring>
  4 #include <algorithm>
  5 #define LL long long
  6 
  7 using namespace std;
  8 
  9 const int MAXN = 4e5 + 10;
 10 int n1, n2, n;
 11 int m;
 12 int sum[MAXN];
 13 char s[MAXN * 3];
 14 char s1[MAXN], s2[MAXN];
 15 int h[MAXN];
 16 LL ans = 0;
 17 
 18 int SA[MAXN], ra[MAXN], cur[MAXN], tp[MAXN], c[MAXN];
 19 
 20 struct s {
 21     int id;
 22     int sum;
 23 } sta[MAXN];
 24 
 25 inline LL read()
 26 {
 27     LL x = 0, w = 1; char ch = 0;
 28     while(ch < '0' || ch > '9') {
 29         if(ch == '-') {
 30             w = -1;
 31         }
 32         ch = getchar();
 33     }
 34     while(ch >= '0' && ch <= '9') {
 35         x = x * 10 + ch - '0';
 36         ch = getchar();
 37     }
 38     return x * w;
 39 }
 40 
 41 void solve(int x) 
 42 {
 43     for(int i = 1; i <= x; i++) {
 44         c[i] = 0;
 45     }
 46     for(int i = 1; i <= n; i++) {
 47         c[ra[tp[i]]]++;
 48     }
 49     for(int i = 1; i <= x; i++) {
 50         c[i] += c[i - 1];
 51     }
 52     for(int i = n; i >= 1; i--) {
 53         SA[c[ra[tp[i]]]--] = tp[i];
 54     }
 55 }
 56 
 57 void copy()
 58 {
 59     for(int i = 1; i <= n; i++) {
 60         cur[i] = ra[i];
 61     }
 62 }
 63 
 64 void suffix()
 65 {
 66     for(int i = 1; i <= n; i++) {
 67         ra[i] = char(s[i]);
 68         tp[i] = i;
 69     }
 70     solve(m = 128);
 71     for(int w = 1, p = 0; p < n; m = p, w += w) {
 72         p = 0;
 73         for(int j = n - w + 1; j <= n; j++) {
 74             tp[++p] = j;
 75         }
 76         for(int i = 1; i <= n; i++) {
 77             if(SA[i] > w) {
 78                 tp[++p] = SA[i] - w;
 79             }
 80         }
 81         solve(m);
 82         copy();
 83         ra[SA[1]] = p = 1;
 84         for(int i = 2; i <= n; i++) {
 85             if(cur[SA[i]] == cur[SA[i - 1]] && cur[SA[i] + w] == cur[SA[i - 1] + w]) {
 86                 ra[SA[i]] = p;
 87             } else {
 88                 ra[SA[i]] = ++p;
 89             }
 90         }
 91     }
 92     int k = 0;
 93     for(int i = 1; i <= n; i++) {
 94         if(k) {
 95             k--;
 96         }
 97         int j = SA[ra[i] - 1];
 98         while(s[i + k] == s[j + k]) {
 99             k++;
100         }
101         h[ra[i]] = k;
102     }
103 }
104 
105 void cal()
106 {
107     int top = 1;
108     sum[0] = 0;
109     for(int i = 1; i <= n; i++) {
110         sum[i] = sum[i - 1];
111         if(SA[i] > n1 + 1) {
112             sum[i]++; 
113         }
114     }
115     for(int i = 1; i <= n; i++) {
116         while(top > 1 && h[sta[top - 1].id] > h[i]) {
117             top--;
118         }
119         sta[top].sum = sta[top - 1].sum + (sum[i - 1] - sum[sta[top - 1].id - 1]) * h[i];
120         sta[top++].id = i;
121         if(SA[i] <= n1) {
122             ans += sta[top - 1].sum;
123         }
124     }
125     sum[0] = 0;
126     top = 1;
127     for(int i = 1; i <= n; i++) {
128         sum[i] = sum[i - 1];
129         if(SA[i] <= n1) {
130             sum[i]++; 
131         }
132     }
133     for(int i = 1; i <= n; i++) {
134         while(top > 1 && h[sta[top - 1].id] > h[i]) {
135             top--;
136         }
137         sta[top].sum = sta[top - 1].sum + (sum[i - 1] - sum[sta[top - 1].id - 1]) * h[i];
138         sta[top++].id = i;
139         if(SA[i] > n1 + 1) {
140             ans += sta[top - 1].sum;
141         }
142     }
143 }
144 int main()
145 {
146     scanf("%s", s1 + 1);
147     scanf("%s", s2 + 1);
148     n1 = strlen(s1 + 1), n2 = strlen(s2 + 1);
149     for(int i = 1; i <= n1; i++) {
150         s[i] = s1[i];
151     }
152     s[n1 + 1] = char('z' + 1);
153     for(int i = 1; i <= n2; i++) {
154         s[i + n1 + 1] = s2[i];
155     }
156     n = strlen(s + 1);
157     suffix();
158     cal(); 
159     printf("%lld\n", ans);
160     return 0;
161 }
162 
163 /*
164 
165 aabb
166 bbaa
167 
168 */
View Code

 

 
posted @ 2018-03-24 14:09  大财主  阅读(304)  评论(0编辑  收藏  举报