字符串

懒得认真写博客了...

圈圈

求每次把字符串全部 + 1 之后的最小表示法。

就是那样,先Hash。

+ 1的时候考虑有哪些地方变成0了,在这些位置中O(logn)比较。

没有变成0的话就是上一次的答案。

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <vector>
  4 #include <algorithm>
  5 
  6 typedef unsigned long long uLL;
  7 
  8 const int N = 50010, B = 1e9 + 7;
  9 
 10 int s[N << 1], n, m, k;
 11 uLL H[N << 1], po[N << 1];
 12 std::vector<int> pos[N];
 13 
 14 inline void gethash(int n) {
 15     H[0] = s[0];
 16     po[0] = 1;
 17     for(int i = 1; i < n * 2; i++) {
 18         po[i] = po[i - 1] * B;
 19         H[i] = H[i - 1] * B + s[i];
 20     }
 21     return;
 22 }
 23 inline uLL Hash(int l, int r) {
 24     if(l == 0) {
 25         return H[r]; /// error : space
 26     }
 27     return H[r] - H[l - 1] * po[r - l + 1];
 28 }
 29 
 30 inline int getmin() {
 31     int i = 0, j = 1;
 32     while(i < n && j < n) {
 33         int k = 0;
 34         while(s[i + k] == s[j + k] && k < n) {
 35             k++;
 36         }
 37         if(k == n) {
 38             return 0;
 39         }
 40         if(s[i + k] > s[j + k]) {
 41             i += k + 1;
 42             if(i == j) {
 43                 i++;
 44             }
 45         }
 46         else {
 47             j += k + 1;
 48             if(i == j) {
 49                 j++;
 50             }
 51         }
 52     }
 53     return std::min(i, j);
 54 }
 55 
 56 inline bool great(int i, int j, int t) {
 57     int l = 0, r = n, mid;
 58     while(l < r) {
 59         mid = (l + r) >> 1;
 60         if(Hash(i, i + mid) == Hash(j, j + mid)) {
 61             l = mid + 1;
 62         }
 63         else {
 64             r = mid;
 65         }
 66     }
 67     return (s[i + r] + t) % m > (s[j + r] + t) % m;
 68 }
 69 
 70 int main() {
 71     freopen("in.in", "r", stdin);
 72     freopen("my.out", "w", stdout);
 73     scanf("%d%d%d", &n, &m, &k);
 74     k--;
 75     for(int i = 0; i < n; i++) {
 76         scanf("%d", &s[i]);
 77         pos[s[i]].push_back(i);
 78     }
 79 
 80     memcpy(s + n, s, n * sizeof(int));
 81     gethash(n);
 82 
 83     int t = getmin();
 84     printf("%d\n", s[t + k]);
 85     //printf("%d\n", t);
 86     for(int i = 1; i < m; i++) {
 87         if(pos[m - i].size()) {
 88             t = pos[m - i][0];
 89             for(int j = 1; j < pos[m - i].size(); j++) {
 90                 if(great(t, pos[m - i][j], i)) {
 91                     t = pos[m - i][j];
 92                 }
 93             }
 94         }
 95         printf("%d\n", (s[t + k] + i) % m);
 96         //printf("%d\n", t);
 97     }
 98 
 99     return 0;
100 }
AC代码

manacher

作用:求出字符串中,以每个位置为中心,最多能向两边扩展多少回文串。

例:  #a#b#a#

f[]:  0 1 0 3 0 1 0

两边中间都要补上#之类的东西,max(f)直接输出就是ans

模板:

 1 #include <cstdio>
 2 #include <cstring>
 3 #include <algorithm>
 4 
 5 const int N = 11000010;
 6 
 7 char s[N << 1];
 8 int f[N << 1];
 9 
10 int main() {
11     scanf("%s", s);
12     int n = strlen(s);
13     n = (n << 1) + 1;
14     for(int i = n - 1; i >= 0; i--) {
15         if(i & 1) {
16             s[i] = s[i >> 1];
17         }
18         else {
19             s[i] = '#';
20         }
21     }
22     //printf("%s\n", s);
23     f[0] = 1;
24     int p = 0, ans = 1;
25     for(int i = 1; i < n; i++) {
26         int t = 2 * p - i;
27         if(t < 0 || i + f[t] >= p + f[p]) {
28             int j = p + f[p] - i + 1;
29             while(i + j < n && i - j >= 0 && s[i + j] == s[i - j]) {
30                 j++;
31             }
32             f[i] = j - 1;
33         }
34         else {
35             f[i] = f[t];
36         }
37         //printf("%d %d\n", i, f[i]);
38         if(i + f[i] > p + f[p]) {
39             p = i;
40         }
41         ans = std::max(ans, f[i]);
42     }
43     printf("%d", ans);
44     return 0;
45 }
mancher模板

 这个是hash替代manacher,可以求出以每个位置为结尾的最长回文串。

 1 #include <cstdio>
 2 #include <cstring>
 3 #include <algorithm>
 4 
 5 typedef unsigned long long uLL;
 6 
 7 const int N = 11000010, B = 13331;
 8 
 9 int f1[N], f2[N], n;
10 char s[N];
11 uLL H1[N], H2[N], po[N];
12 
13 inline void gethash() {
14     H1[0] = s[0];
15     po[0] = 1;
16     for(int i = 1; i < n; i++) {
17         H1[i] = H1[i - 1] * B + s[i];
18         po[i] = po[i - 1] * B;
19     }
20     H2[n - 1] = s[n - 1];
21     for(int i = n - 2; i >= 0; i--) {
22         H2[i] = H2[i + 1] * B + s[i];
23     }
24     return;
25 }
26 inline uLL Hash1(int l, int r) {
27     if(!l) {
28         return H1[r];
29     }
30     return H1[r] - H1[l - 1] * po[r - l + 1];
31 }
32 inline uLL Hash2(int l, int r) {
33     if(r == n - 1) {
34         return H2[l];
35     }
36     return H2[l] - H2[r + 1] * po[r - l + 1];
37 }
38 
39 int main() {
40     scanf("%s", s);
41     n = strlen(s);
42     gethash();
43     f1[0] = 1;                   //  01234567
44     int ans = 1;                 //  babcbabc
45     for(int i = 1; i < n; i++) { //  11313575
46         if(f1[i - 1] < i && s[i] == s[i - f1[i - 1] - 1]) {
47             f1[i] = f1[i - 1] + 2;
48         }
49         else {
50             int k = i - f1[i - 1];
51             while(Hash1(k, i) != Hash2(k, i)) {
52                 k++;
53             }
54             f1[i] = i - k + 1;
55         }
56         ans = std::max(ans, f1[i]);
57         //printf("%d %d \n", i, f1[i]);
58     }
59     printf("%d", ans);
60     return 0;
61 }
AC代码

最长双回文串。

这个很奇怪...O(n)的算法,luogu上面跑的飞起,bzoj就T。

先放洛谷AC代码吧。

 1 #include <cstdio>
 2 #include <cstring>
 3 #include <algorithm>
 4 
 5 typedef unsigned long long uLL;
 6 
 7 const int N = 11000010, B = 13331;
 8 
 9 int f1[N], f2[N], n;
10 char s[N];
11 uLL H1[N], H2[N], po[N];
12 
13 inline void gethash() {
14     H1[0] = s[0];
15     po[0] = 1;
16     for(int i = 1; i < n; i++) {
17         H1[i] = H1[i - 1] * B + s[i];
18         po[i] = po[i - 1] * B;
19     }
20     H2[n - 1] = s[n - 1];
21     for(int i = n - 2; i >= 0; i--) {
22         H2[i] = H2[i + 1] * B + s[i];
23     }
24     return;
25 }
26 inline uLL Hash1(int l, int r) {
27     if(!l) {
28         return H1[r];
29     }
30     return H1[r] - H1[l - 1] * po[r - l + 1];
31 }
32 inline uLL Hash2(int l, int r) {
33     if(r == n - 1) {
34         return H2[l];
35     }
36     return H2[l] - H2[r + 1] * po[r - l + 1];
37 }
38 
39 int main() {
40     scanf("%s", s);
41     n = strlen(s);
42     gethash();                   //  01234567
43     f1[0] = f2[n - 1] = 1;       //  babcbabc
44                                  //  75353111
45     for(int i = 1; i < n; i++) { //  11313575
46         if(f1[i - 1] < i && s[i] == s[i - f1[i - 1] - 1]) {
47             f1[i] = f1[i - 1] + 2;
48         }
49         else {
50             int k = i - f1[i - 1];
51             while(Hash1(k, i) != Hash2(k, i)) {
52                 k++;
53             }
54             f1[i] = i - k + 1;
55         }
56     }
57     for(int i = n - 2; i >= 0; i--) {
58         if(i + f2[i + 1] + 1 < n && s[i] == s[i + f2[i + 1] + 1]) {
59             f2[i] = f2[i + 1] + 2;
60         }
61         else {
62             int k = i + f2[i + 1];
63             while(Hash1(i, k) != Hash2(i, k)) {
64                 k--;
65             }
66             f2[i] = k - i + 1;
67         }
68     }
69 
70     int ans = 2;
71     for(int i = 1; i < n; i++) {
72         ans = std::max(ans, f1[i - 1] + f2[i]);
73     }
74     printf("%d", ans);
75     return 0;
76 }
洛谷P4555 AC代码

 咳咳,问题查出来了,N忘了改,开了1100 0000的uLL数组。。。BZ MLE显示TLE


POJ1509

poj的毒瘤也不是第一次见了,莫名其妙的WA,然后就A...

最小表示法模板题。

 1 #include <cstdio>
 2 #include <cstring>
 3 #include <algorithm>
 4 
 5 const int N = 10010;
 6 
 7 char s[N << 1];
 8 
 9 int main() {
10     int T;
11     scanf("%d", &T);
12     while(T--) {
13         scanf("%s", s);
14         int n = strlen(s);
15         memcpy(s + n, s, n * sizeof(char));
16         int i = 1, j = 0;
17         while(i < n && j < n) {
18             int k = 0;
19             while(k < n && s[i + k] == s[j + k]) {
20                 k++;
21             }
22             if(k == n) {
23                 break;
24             }
25             if(s[i + k] > s[j + k]) {
26                 i += k + 1;
27                 if(i == j) {
28                     i++;
29                 }
30             }
31             else {
32                 j += k + 1;
33                 if(j == i) {
34                     j++;
35                 }
36             }
37         }
38         printf("%d\n", std::min(i, j) + 1);
39     }
40     return 0;
41 }
AC代码

 APIO2014 回文串

PAM业界毒瘤!!!

这句话不得不说,毕竟理解它花了几分钟,但是两天都打不出来...

最后直接摘抄模板,不管了。注意!insert开新点那里不能先给tr[p][f]赋值。

 1 #include <cstdio>
 2 #include <cstring>
 3 #include <algorithm>
 4 
 5 typedef long long LL;
 6 
 7 const int N = 300010;
 8 
 9 char s[N];
10 int n;
11 
12 struct PAM {
13     int tr[N][26], cnt[N], fail[N], len[N], num[N];
14     int top, last;
15     inline void init() { // !!!
16         len[0] = 0;
17         len[1] = -1;
18         fail[0] = 1;
19         fail[1] = 1;
20         last = 0;
21         top = 1;
22         return;
23     }
24     inline int getfail(int d, int x) {
25         while(s[d - len[x] - 1] != s[d]) {
26             x = fail[x];
27         }
28         return x;
29     }
30     inline void insert(int d) {
31         int f = s[d] - 'a';
32         int p = getfail(d, last);
33         if(!tr[p][f]) {
34             ++top;
35             len[top] = len[p] + 2; // 长度
36             fail[top] = tr[getfail(d, fail[p])][f]; // 最长 靠右 子回文串
37             num[top] = num[fail[top]] + 1; // 这个串内靠右 回文串数
38             tr[p][f] = top; // 两端 + f 转移
39         }
40         last = tr[p][f]; // 末尾最长回文串
41         cnt[last]++; // 该回文串出现次数
42     }
43     inline void count() { // 统计 cnt
44         for(int i = top; i >= 0; i--) {
45             cnt[fail[i]] += cnt[i];
46         }
47     }
48 }pam;
49 
50 int main() {
51     scanf("%s", s);
52     n = strlen(s);
53     pam.init();
54     for(int i = 0; i < n; i++) {
55         pam.insert(i);
56     }
57     pam.count();
58 
59     LL ans = 0;
60     for(int i = 2; i <= pam.top; i++) {
61         ans = std::max(ans, 1ll * pam.len[i] * pam.cnt[i]);
62     }
63     printf("%lld", ans);
64     
65     return 0;
66 }
AC代码

有个小问题就是char数组的-1位是空...不知道为什么。


SHOI2011 双倍回文

这题用回文自动机秒...

先打了个裸露在外的暴力,T了一个点,然后剪个枝就A了,跑的还贼快...

 1 #include <cstdio>
 2 #include <cstring>
 3 #include <algorithm>
 4 
 5 const int N = 500010;
 6 
 7 char s[N];
 8 int n;
 9 
10 struct PAM {
11     int tr[N][26], cnt[N], len[N], num[N], fail[N];
12     int top, last;
13     inline void init() {
14         len[0] = 0;
15         len[1] = -1;
16         fail[0] = 1;
17         fail[1] = 1;
18         top = 1;
19         last = 0;
20         return;
21     }
22     inline int getfail(int d, int x) {
23         while(s[d - len[x] - 1] != s[d]) {
24             x = fail[x];
25         }
26         return x;
27     }
28     inline void insert(int d) {
29         int f = s[d] - 'a';
30         int p = getfail(d, last);
31         if(!tr[p][f]) {
32             ++top;
33             fail[top] = tr[getfail(d, fail[p])][f];
34             len[top] = len[p] + 2;
35             num[top] = num[fail[top]] + 1;
36             tr[p][f] = top;
37         }
38         last = tr[p][f];
39         cnt[last]++;
40         return;
41     }
42     inline void count() {
43         for(int i = top; i >= 0; i--) {
44             cnt[fail[i]] += cnt[i];
45         }
46         return;
47     }
48 }pam;
49 
50 int main() {
51     pam.init();
52     scanf("%d", &n);
53     scanf("%s", s);
54     for(int i = 0; i < n; i++) {
55         pam.insert(i);
56     }
57     int ans = 0;
58     for(int i = pam.top; i > 1; i--) {
59         if(pam.len[i] % 4 || pam.len[i] <= ans) { 
60             continue;
61         }
62         int j = pam.len[i], p = pam.fail[i];
63         while(pam.len[p] << 1 > j) {
64             p = pam.fail[p];
65         }
66         if(pam.len[p] << 1 == j) {
67             ans = std::max(ans, j);
68         }
69     }
70 
71     printf("%d", ans);
72     return 0;
73 }
AC代码

洛谷P1872 回文串计数

仔细分析之后也可以用PAM秒。

只要记录每个位置为结尾的回文串数量,翻转之后再统计每个位置之前的所有回文串数量,乘起来即可。

 1 #include <cstdio>
 2 #include <cstring>
 3 #include <algorithm>
 4 
 5 typedef long long LL;
 6 
 7 const int N = 2010;
 8 
 9 char s[N];
10 int n, ans[N], tot;
11 
12 struct PAM {
13     int tr[N][26], len[N], cnt[N], num[N], fail[N];
14     int last, top;
15     inline void init() {
16         len[0] = 0;
17         len[1] = -1;
18         fail[0] = fail[1] = 1;
19         last = 0;
20         top = 1;
21         return;
22     }
23     PAM() {
24         init();
25     }
26     inline int getfail(int d, int x) {
27         while(s[d - len[x] - 1] != s[d]) {
28             x = fail[x];
29         }
30         return x;
31     }
32     inline void insert(int d) {
33         int f = s[d] - 'a';
34         int p = getfail(d, last);
35         if(!tr[p][f]) {
36             ++top;
37             fail[top] = tr[getfail(d, fail[p])][f];
38             len[top] = len[p] + 2;
39             num[top] = num[fail[top]] + 1;
40             tr[p][f] = top;
41         }
42         last = tr[p][f];
43         cnt[last]++;
44         ans[d] = num[last];
45         tot += num[last];
46     }
47     inline void count() {
48         for(int i = top; i >= 0; i--) {
49             cnt[fail[i]] += cnt[i];
50         }
51         return;
52     }
53     inline void clear() {
54         for(int i = 0; i <= top; i++) {
55             for(int j = 0; j < 26; j++) {
56                 tr[i][j] = 0;
57             }
58             len[i] = fail[i] = num[i] = cnt[i] = 0;
59         }
60         init();
61         return;
62     }
63 }pam;
64 
65 int main() {
66     scanf("%s", s);
67     n = strlen(s);
68     for(int i = 0; i < n; i++) {
69         pam.insert(i);
70     }
71     pam.clear();
72     std::reverse(s, s + n);
73     std::reverse(ans, ans + n);
74     tot = 0;
75     LL t = 0;
76     for(int i = 0; i < n - 1; i++) {
77         pam.insert(i);
78         t += 1ll * tot * ans[i + 1];
79         //printf("%d %d\n", tot, ans[i + 1]);
80     }
81     printf("%lld", t);
82     return 0;
83 }
AC代码

国家集训队 拉拉队排练

嗯,依旧是PAM水题。 

我的想法是把每个奇回文串搞一个结构体,排序之后快速幂。

T了最后一个点...于是跑去看题解,发现可以用桶装,继续T最后一个点...

最后发现是输入的 k 爆int了,所以不知怎地超时的……

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <algorithm>
  4 
  5 typedef long long LL;
  6 
  7 const int N = 1000010;
  8 const LL MO = 19930726;
  9 
 10 char s[N];
 11 int n;
 12 LL bin[N];
 13 LL qpow(int aa, int b);
 14 
 15 struct PAM {
 16     int tr[N][26], fail[N], len[N], cnt[N], num[N];
 17     int last, top;
 18     inline void init() {
 19         len[0] = 0;
 20         len[1] = -1;
 21         fail[0] = fail[1] = 1;
 22         last = 0;
 23         top = 1;
 24         return;
 25     }
 26     PAM() {
 27         init();
 28     }
 29 
 30     inline int getfail(int d, int x) {
 31         while(s[d - len[x] - 1] != s[d]) {
 32             x = fail[x];
 33         }
 34         return x;
 35     }
 36     inline void insert(int d) {
 37         int f = s[d] - 'a';
 38         int p = getfail(d, last);
 39         if(!tr[p][f]) {
 40             ++top;
 41             fail[top] = tr[getfail(d, fail[p])][f];
 42             len[top] = len[p] + 2;
 43             num[top] = num[fail[top]] + 1;
 44             tr[p][f] = top;
 45         }
 46         last = tr[p][f];
 47         cnt[last]++;
 48         return;
 49     }
 50     inline void count() {
 51         for(int i = top; i >= 0; i--) {
 52             cnt[fail[i]] += cnt[i];
 53             if(len[i] & 1 && i > 1) {
 54                 bin[len[i]] += cnt[i];
 55             }
 56         }
 57         return;
 58     }
 59 }pam;
 60 
 61 inline LL qpow(int aa, int b) {
 62     LL a = (LL)(aa) % MO;
 63     LL ans = 1;
 64     while(b) {
 65         if(b & 1) {
 66             ans = (ans * a) % MO;
 67         }
 68         a = (a * a) % MO;
 69         b = b >> 1;
 70     }
 71     return ans;
 72 }
 73 
 74 int main() {
 75     //printf("%d", (int)(MO * MO));
 76     LL k;
 77     scanf("%d%lld", &n, &k);
 78     scanf("%s", s);
 79     for(int i = 0; i < n; i++) {
 80         pam.insert(i);
 81     }
 82     pam.count();
 83 /*
 84     LL tp = 0;
 85     for(int i = 2; i <= pam.top; i++) {
 86         tp += pam.cnt[i];
 87     }
 88     if(k > tp) {
 89         printf("-1");
 90         return 0;
 91     }
 92 */
 93     int i = n;
 94     LL ans = 1;
 95     for(int i = n; i && k; i--) {
 96         if(!bin[i]) {
 97             continue;
 98         }
 99         if(bin[i] <= k) {
100             ans = (ans * qpow(i, bin[i])) % MO;
101             k -= bin[i];
102         }
103         else {
104             ans = (ans * qpow(i, k)) % MO;
105             k = 0;
106             break;
107         }
108     }
109 
110     if(k) {
111         ans = -1;
112     }
113     printf("%lld", ans);
114     return 0;
115 }
AC代码(桶)

 PAM总结:

回文自动机,一个点代表一种回文串,支持以下功能:

1,统计所有回文串数。cnt的2 ~ top求和/途中累加num

2,统计所有本质不同回文串数。top - 1

3,统计所有本质不同回文串出现次数。count()之后的cnt

4,统计以第i位结尾的回文串个数。途中num

等等...


SAM

这个东西TM比PAM还毒瘤...我真是要被气死了。

跟PAM一样调了两天,一行一行的照抄别人的模板,就是T...

然后把数组开到2n就A了啊啊啊啊啊啊啊毒瘤!!!

模板:

 1 #include <cstdio>
 2 #include <cstring>
 3 #include <algorithm>
 4 typedef long long LL;
 5 const int N = 2000010;
 6 char s[N];
 7 int n;
 8 LL ans;
 9 struct SAM {
10     int tr[N][26], fail[N], len[N], cnt[N], bin[N], topo[N];
11     int root, top, last;
12 
13     inline void init() {
14         top = 1;
15         last = 1;
16         root = 1;
17         return;
18     }
19     SAM() {
20         init();
21     }
22 
23     inline void insert(char c) {
24         int f = c - 'a';
25         int p = last, np = ++top;
26         last = np;
27         cnt[np] = 1;
28         len[np] = len[p] + 1;
29         while(p && !tr[p][f]) {
30             tr[p][f] = np;
31             p = fail[p];
32         }
33         if(!p) {
34             fail[np] = root;
35         }
36         else {
37             int Q = tr[p][f];
38             if(len[Q] == len[p] + 1) {
39                 fail[np] = Q;
40             }
41             else {
42                 int nQ = ++top;
43                 len[nQ] = len[p] + 1;
44                 fail[nQ] = fail[Q];
45                 fail[Q] = fail[np] = nQ;
46                 memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
47                 while(tr[p][f] == Q) {
48                     tr[p][f] = nQ;
49                     p = fail[p];
50                 }
51             }
52         }
53         return;
54     }
55     inline void sort() {
56         for(int i = 1; i <= top; i++) {
57             bin[len[i]]++;
58         }
59         for(int i = 1; i <= top; i++) {
60             bin[i] += bin[i - 1];
61         }
62         for(int i = 1; i <= top; i++) {
63             topo[bin[len[i]]--] = i;
64         }
65         return;
66     }
67     inline void cal() {
68         for(int i = top; i; i--) {
69             int x = topo[i];
70             cnt[fail[x]] += cnt[x];
71             if(cnt[x] > 1) {
72                 ans = std::max(ans, 1ll * len[x] * cnt[x]);
73             }
74         }
75         return;
76     }
77 }sam;
78 
79 int main() {
80     scanf("%s", s);
81     n = strlen(s);
82     for(int i = 0; i < n; i++) {
83         sam.insert(s[i]);
84     }
85     sam.sort();
86     sam.cal();
87     printf("%lld", ans);
88 
89     return 0;
90 }
洛谷P3804

hihocoder 1441 

毒瘤...

如何求一个点的right集合?

他的子树中所有在主链上的len( - 1)

这个卡了我好久。。。不能只取叶子,因为iiaaaii,也不能全取,因为ciiaaaii

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <algorithm>
  4 const int N = 100010;
  5 char s[N], p[N];
  6 int n, A[N], num;
  7 
  8 struct SAM {
  9     int tr[N][26], fail[N], len[N], cnt[N], topo[N], bin[N];
 10     int root, top, last;
 11     bool is_new[N];
 12     ///---------------------
 13     struct Edge {
 14         int v, nex;
 15     }edge[N]; int t;
 16     int e[N];
 17     inline void add(int x, int y) {
 18         t++;
 19         edge[t].v = y;
 20         edge[t].nex = e[x];
 21         e[x] = t;
 22         return;
 23     }
 24     ///---------------------
 25     inline void init() {
 26         root = 1;
 27         top = 1;
 28         last = 1;
 29         t = 0;
 30         return;
 31     }
 32     SAM() {
 33         init();
 34     }
 35     inline void insert(char c) {
 36         int f = c - 'a';
 37         int p = last, np = ++top;
 38         last = np;
 39         len[np] = len[p] + 1;
 40         cnt[np] = 1;
 41         is_new[np] = 1;
 42         while(p && !tr[p][f]) {
 43             tr[p][f] = np;
 44             p = fail[p];
 45         }
 46         if(!p) {
 47             fail[np] = root;
 48         }
 49         else {
 50             int Q = tr[p][f];
 51             if(len[Q] == len[p] + 1) {
 52                 fail[np] = Q;
 53             }
 54             else {
 55                 int nQ = ++top;
 56                 fail[nQ] = fail[Q];
 57                 fail[Q] = fail[np] = nQ;
 58                 len[nQ] = len[p] + 1;
 59                 memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
 60                 while(tr[p][f] == Q) {
 61                     tr[p][f] = nQ;
 62                     p = fail[p];
 63                 }
 64             }
 65         }
 66         return;
 67     }
 68     inline void sort() {
 69         for(int i = 1; i <= top; i++) {
 70             bin[len[i]]++;
 71         }
 72         for(int i = 1; i <= top; i++) {
 73             bin[i] += bin[i - 1];
 74         }
 75         for(int i = 1; i <= top; i++) {
 76             topo[bin[len[i]]--] = i;
 77         }
 78         return;
 79     }
 80     inline void build() {
 81         for(int i = 2; i <= top; i++) {
 82             add(fail[i], i);
 83         }
 84         return;
 85     }
 86     inline void DFS(int x) {
 87         for(int i = e[x]; i; i = edge[i].nex) {
 88             int y = edge[i].v;
 89             DFS(y);
 90         }
 91         if(is_new[x]) {
 92             A[++num] = len[x];
 93         }
 94         return;
 95     }
 96 }sam;
 97 
 98 int main() {
 99     scanf("%s", s);
100     n = strlen(s);
101     for(int i = 0; i < n; i++) {
102         sam.insert(s[i]);
103     }
104     sam.build();
105     int T, m;
106     scanf("%d", &T);
107     while(T--) {
108         scanf("%s", p);
109         m = strlen(p);
110         int k = 1;
111         for(int i = 0; i < m; i++) {
112             int f = p[i] - 'a';
113             k = sam.tr[k][f];
114         }
115         sam.DFS(k);
116         int l = sam.len[sam.fail[k]] + 1;
117         int r = sam.len[k];
118         for(int i = A[1] - l; i < A[1]; i++) {
119             putchar(s[i]);
120         }
121         putchar(' ');
122         for(int i = A[1] - r; i < A[1]; i++) {
123             putchar(s[i]);
124         }
125         putchar(' ');
126         std::sort(A + 1, A + num + 1);
127         for(int i = 1; i <= num; i++) {
128             printf("%d ", A[i]);
129         }
130         puts("");
131         num = 0;
132     }
133 
134     return 0;
135 }
AC代码

hihocoder1445 

求本质不同子串数量。

累加len - len[fail] 即可

 1 #include <cstdio>
 2 #include <cstring>
 3 typedef long long LL;
 4 const int N = 2000010;
 5 char s[N];
 6 int n;
 7 
 8 struct SAM {
 9     int tr[N][26], fail[N], len[N];
10     int root, top, last;
11     inline void init() {
12         root = 1;
13         top = 1;
14         last = 1;
15         return;
16     }
17     SAM() {
18         init();
19     }
20 
21     inline void insert(char c) {
22         int f = c - 'a';
23         int p = last, np = ++top;
24         last = np;
25         len[np] = len[p] + 1;
26         while(p && !tr[p][f]) {
27             tr[p][f] = np;
28             p = fail[p];
29         }
30         if(!p) {
31             fail[np] = root;
32         }
33         else {
34             int Q = tr[p][f];
35             if(len[Q] == len[p] + 1) {
36                 fail[np] = Q;
37             }
38             else {
39                 int nQ = ++top;
40                 fail[nQ] = fail[Q];
41                 fail[Q] = fail[np] = nQ;
42                 len[nQ] = len[p] + 1;
43                 memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
44                 while(tr[p][f] == Q) {
45                     tr[p][f] = nQ;
46                     p = fail[p];
47                 }
48             }
49         }
50         return;
51     }
52 }sam;
53 
54 int main() {
55     scanf("%s", s);
56     n = strlen(s);
57     for(int i = 0; i < n; i++) {
58         sam.insert(s[i]);
59     }
60     LL ans = 0;
61     for(int i = 2; i <= sam.top; i++) {
62         ans += sam.len[i] - sam.len[sam.fail[i]];
63     }
64     printf("%lld", ans);
65     return 0;
66 }
AC代码

hihocoder1449 

分别求长度为 1 ~ n 的子串中出现最多的那个子串出现的次数。

考虑到SAM的一个点代表多个子串,我们要支持区间取max

我觉得线段树可能有问题,就用了O(n)的算法...

后来问了大佬发现线段树可以滋磁这种操作。

我是直接跟更新len,最后从大到小扫一遍。

 1 #include <cstdio>
 2 #include <cstring>
 3 #include <algorithm>
 4 const int N = 1000010;
 5 char s[N];
 6 int n, large[N];
 7 
 8 struct SAM {
 9     int tr[N << 1][26], fail[N << 1], len[N << 1], cnt[N << 1];
10     int bin[N << 1], topo[N << 1];
11     int root, top, last;
12     inline void init() {
13         top = 1;
14         root = 1;
15         last = 1;
16         return;
17     }
18     SAM() {
19         init();
20     }
21 
22     inline void insert(char c) {
23         int f = c - 'a';
24         int p = last, np = ++top;
25         last = np;
26         len[np] = len[p] + 1;
27         cnt[np] = 1;
28         while(p && !tr[p][f]) {
29             tr[p][f] = np;
30             p = fail[p];
31         }
32         if(!p) {
33             fail[np] = root;
34         }
35         else {
36             int Q = tr[p][f];
37             if(len[Q] ==  len[p] + 1) {
38                 fail[np] = Q;
39             }
40             else {
41                 int nQ = ++top;
42                 len[nQ]  = len[p] + 1;
43                 fail[nQ] = fail[Q];
44                 fail[Q] = fail[np] = nQ;
45                 memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
46                 while(tr[p][f] == Q) {
47                     tr[p][f] = nQ;
48                     p = fail[p];
49                 }
50             }
51         }
52         return;
53     }
54     inline void sort() {
55         for(int i = 1; i <= top; i++) {
56             bin[len[i]]++;
57         }
58         for(int i = 1; i <= top; i++) {
59             bin[i] += bin[i - 1];
60         }
61         for(int i = 1; i <= top; i++) {
62             topo[bin[len[i]]--] = i;
63         }
64         return;
65     }
66     inline void cal() {
67         for(int i = top; i >= 1; i--) {
68             int x = topo[i];
69             cnt[fail[x]] += cnt[x];
70             large[len[x]] = std::max(large[len[x]], cnt[x]);
71         }
72         return;
73     }
74 }sam;
75 
76 int main() {
77     scanf("%s", s);
78     n = strlen(s);
79     for(int i = 0; i < n; i++) {
80         sam.insert(s[i]);
81     }
82     sam.sort();
83     sam.cal();
84     for(int i = n; i >= 1; i--) {
85         large[i] = std::max(large[i], large[i + 1]);
86     }
87     for(int i = 1; i <= n; i++) {
88         printf("%d\n", large[i]);
89     }
90     return 0;
91 }
AC代码

HAOI2016 找相同字符

求两个字符串的相同子串的个数。

允许本质相同。

首先对第一个子串建出SAM。

第二个串匹配到i时,统计以第二个串第i位结尾的答案即可。

那么第i位,匹配到节点p的答案就是:

(当前匹配长度 - len[fail[p]]) * cnt[p] + p的所有父节点的cnt * △len

后半部分可以预处理出一个ans数组来,拓扑序递推即可。

  1 #include <cstdio>
  2 #include <cstring>
  3 typedef long long LL;
  4 const int N = 400010;
  5 char s[N], s2[N];
  6 int n, m, ans[N];
  7 struct SAM {
  8     int tr[N][26], fail[N], len[N], cnt[N], bin[N], topo[N];
  9     int top, root, last;
 10     inline void init() {
 11         top = 1;
 12         root = 1;
 13         last = 1;
 14         return;
 15     }
 16     SAM() {
 17         init();
 18     }
 19 
 20     inline void insert(char c) {
 21         int f = c - 'a';
 22         int p = last, np = ++top;
 23         last = np;
 24         len[np] = len[p] + 1;
 25         cnt[np] = 1;
 26         while(p && !tr[p][f]) {
 27             tr[p][f] = np;
 28             p = fail[p];
 29         }
 30         if(!p) {
 31             fail[np] = root;
 32         }
 33         else {
 34             int Q = tr[p][f];
 35             if(len[Q] == len[p] + 1) {
 36                 fail[np] = Q;
 37             }
 38             else {
 39                 int nQ = ++top;
 40                 len[nQ] = len[p] + 1;
 41                 fail[nQ] = fail[Q];
 42                 fail[Q] = fail[np] = nQ;
 43                 memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
 44                 while(tr[p][f] == Q) {
 45                     tr[p][f] = nQ;
 46                     p = fail[p];
 47                 }
 48             }
 49         }
 50         return;
 51     }
 52     inline void sort() {
 53         for(int i = 1; i <= top; i++) {
 54             bin[len[i]]++;
 55         }
 56         for(int i = 1; i <= top; i++) {
 57             bin[i] += bin[i - 1];
 58         }
 59         for(int i = 1; i <= top; i++) {
 60             topo[bin[len[i]]--] = i;
 61         }
 62         return;
 63     }
 64     inline void cal() {
 65         for(int i = top; i; i--) {
 66             int x = topo[i];
 67             cnt[fail[x]] += cnt[x];
 68         }
 69         for(int i = 1; i <= top; i++) {
 70             ans[i] = cnt[i] * (len[i] - len[fail[i]]);
 71         }
 72         for(int i = 1; i <= top; i++) {
 73             int x = topo[i];
 74             ans[x] += ans[fail[x]];
 75         }
 76         return;
 77     }
 78 }sam;
 79 int main() {
 80     scanf("%s", s);
 81     scanf("%s", s2);
 82     n = strlen(s);
 83     m = strlen(s2);
 84     for(int i = 0; i < n; i++) {
 85         sam.insert(s[i]);
 86     }
 87     sam.sort();
 88     sam.cal();
 89     int p = 1, len = 0;
 90     LL t = 0;
 91     for(int i = 0; i < m; i++) {
 92         int f = s2[i] - 'a';
 93         while(p != 1 && !sam.tr[p][f]) {
 94             p = sam.fail[p];
 95             len = sam.len[p];
 96         }
 97         if(sam.tr[p][f]) {
 98             p = sam.tr[p][f];
 99             len++;
100         }
101         if(p != 1) {
102             t += sam.cnt[p] * (len - sam.len[sam.fail[p]]) + ans[sam.fail[p]];
103         }
104     }
105     printf("%lld", t);
106     return 0;
107 } 
AC代码

 

posted @ 2018-07-16 16:59  garage  阅读(225)  评论(0编辑  收藏  举报