2019牛客多校第四场 I题 后缀自动机_后缀数组_求两个串de公共子串的种类数

@(牛客多校第四场 I题 string)

求若干个串的公共子串个数相关变形题

  • 牛客这题题意大概是求一个长度为\(2e5\)的字符串有多少个不同子串,若\(s==t\)\(s==rev(t)\)则认为子串\(s,t\)相同。我们知道回文串肯定和他的反串相同。
  • 链接:传送门

做法1:

  • \(yx\)大佬秒出思路%%,对\(s\)串建后缀自动机,可以得到串\(s\)本质不同的子串的个数\(all\),然后只要能减去有多少个串\(x\)\(rev(x)\)同时也出现了即可。
  • 考虑先求出\(s\)\(rev(s)\)的本质不同的公共子串的数量\(res\),串\(s\)本质不同的回文串数量为\(q\),显然\(res-q\)肯定是\(2\)的倍数。求回文串数量是个板子题:here
  • 因为\(s\)\(rev(s)\)本质不同的公共子串除了回文串,就只有非回文串且\(x=rev(x)\)的串了。又因为\(x\)\(rev(x)\)只能算一次贡献,所以最后答案就是\(all-\frac {res-q} 2\)
  • 所以我们现在只要能求出串\(t\)与串\(s\)的公共子串种类数量即可。(还有一种题是求长度至少为k的公共子串数量

做法2:

广义后缀自动机直接求即可。

用普通后缀自动机也有更简单做法,我在第一个做法下面有讲解。

做法3:

后缀数组


对一个串建后缀自动机,另一个串在上面跑同时计数

  • 构建好\(s\)串的后缀自动机后,从根节点开始用\(t\)串在上面匹配,记录一下已经匹配的\(lcs\)长度\(LEN\)。若\(u\)节点有\(t[i]\)这个后继,则\(u\)跳到\(nex[u][t[i]-'a'],LEN++\);如果没有这个后继,就从\(u\)开始沿着后缀连接树向上走直到碰到一个节点\(x\)\(t[i]\)这个后继或者到了根节点\(x\),则\(u = nex[x][t[i]-'a'],LEN=len[x]+1\)
  • 算贡献就是我当前在\(u\)节点,\(lcs\)长度为\(len\),那么\(LEN-len[link[u]]\)就是符合条件的子串。但是这不完全,就是如果\(len[link[u]]\)也大于\(0\)的话,那么他的父亲状态\(link[u]\)是有符合条件的子串,而且符合条件的子串的数量是固定的:\(len[u]-len[link[u]]\)
  • 听说如果你每次走后缀连接树算完所有贡献的话会\(tle\),一个优化就是匹配结束后,逆拓扑排序更新父亲结点的出现次数。像线段树一样用一个\(lazy\)标记记录它是否需要更新,要记得把\(lazy\)标记向父亲上传。
  • 但是这样不够,因为还有一部分贡献没有计算,你可能多次匹配到自动机上的一个节点,我们需要记录一下匹配到每个节点的最长\(lcs\)长度即\(vis[u]\),若\(vis[u]\)等于\(0\),则贡献如上,反之贡献为\(LEN-vis[u]\),最后更新\(vis[u]\)\(LEN\)
  • 本题结束。

其实还有一个更简单的方法,把串\(s\)和串\(rev(s)\)用一个没有出现过得字符拼接起来,求出新字符串的本质不同的子串个数\(x\),我们知道包含那个未出现过字符的子串数量为\(y = (Len+1)\times (Len+1)\),(注意串\((ba)\)和串\((ab)\)只能计一个贡献)然后在求出\(s\)本质不同的回文串个数\(p\),答案就是\(\frac{x-y+p}2\)

#pragma comment(linker, "/STACK:102400000,102400000")
#include<bits/stdc++.h>
#define fi first
#define se second
#define endl '\n'
#define o2(x) (x)*(x)
#define BASE_MAX 30
#define mk make_pair
#define eb emplace_back
#define all(x) (x).begin(), (x).end()
#define clr(a, b) memset((a),(b),sizeof((a)))
#define iis std::ios::sync_with_stdio(false); cin.tie(0)
#define my_unique(x) sort(all(x)),x.erase(unique(all(x)),x.end())
using namespace std;
#pragma optimize("-O3")
typedef long long LL;
typedef pair<int, int> pii;
inline LL read() {
    LL x = 0;int f = 0;
    char ch = getchar();
    while (ch < '0' || ch > '9') f |= (ch == '-'), ch = getchar();
    while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
    return x = f ? -x : x;
}
inline void write(LL x) {
    if (x == 0) {putchar('0'), putchar('\n');return;}
    if (x < 0) {putchar('-');x = -x;}
    static char s[23];
    int l = 0;
    while (x != 0)s[l++] = x % 10 + 48, x /= 10;
    while (l)putchar(s[--l]);
    putchar('\n');
}
int lowbit(int x) { return x & (-x); }
template<class T>T big(const T &a1, const T &a2) { return a1 > a2 ? a1 : a2; }
template<typename T, typename ...R>T big(const T &f, const R &...r) { return big(f, big(r...)); }
template<class T>T sml(const T &a1, const T &a2) { return a1 < a2 ? a1 : a2; }
template<typename T, typename ...R>T sml(const T &f, const R &...r) { return sml(f, sml(r...)); }
void debug_out() { cerr << '\n'; }
template<typename T, typename ...R>void debug_out(const T &f, const R &...r) {cerr << f << " ";debug_out(r...);}
#define debug(...) cerr << "[" << #__VA_ARGS__ << "]: ", debug_out(__VA_ARGS__);

#define print(x) write(x);

const LL INFLL = 0x3f3f3f3f3f3f3f3fLL;
const int HMOD[] = {1000000009, 1004535809};
const LL BASE[] = {1572872831, 1971536491};
const int mod = 998244353;
const int MOD = 1e9 + 7;
const int INF = 0x3f3f3f3f;
const int MXN = 1e6 + 7;

int n;
char s[MXN], t[MXN];
LL all, ANS;
int vis[MXN], lazy[MXN];
struct Palindromic_Tree {
    static const int MAXN = 600005 ;
    static const int CHAR_N = 26 ;
    int next[MAXN][CHAR_N];//next指针,next指针和字典树类似,指向的串为当前串两端加上同一个字符构成
    int fail[MAXN];//fail指针,失配后跳转到fail指针指向的节点
    int cnt[MAXN];
    int num[MAXN];
    int len[MAXN];//len[i]表示节点i表示的回文串的长度
    int S[MAXN];//存放添加的字符
    int last;//指向上一个字符所在的节点,方便下一次add
    int n;//字符数组指针
    int p;//节点指针
    int pos[MAXN];
    int newnode(int l) {//新建节点
        for (int i = 0; i < CHAR_N; ++i) next[p][i] = 0;
        cnt[p] = 0;
        num[p] = 0;
        len[p] = l;
        return p++;
    }
    void init() {//初始化
        p = 0;
        newnode(0);
        newnode(-1);
        last = 0;
        n = 0;
        S[n] = -1;//开头放一个字符集中没有的字符,减少特判
        fail[0] = 1;
    }
    int get_fail(int x) {//和KMP一样,失配后找一个尽量最长的
        while (S[n - len[x] - 1] != S[n]) x = fail[x];
        return x;
    }
    void add(int c, int id) {
        c -= 'a';
        S[++n] = c;
        int cur = get_fail(last);//通过上一个回文串找这个回文串的匹配位置
        if (!next[cur][c]) {//如果这个回文串没有出现过,说明出现了一个新的本质不同的回文串
            int now = newnode(len[cur] + 2);//新建节点
            fail[now] = next[get_fail(fail[cur])][c];//和AC自动机一样建立fail指针,以便失配后跳转
            next[cur][c] = now;
            num[now] = num[fail[now]] + 1;
        }
        last = next[cur][c];
        cnt[last] ++;
        pos[last] = id;
    }
    void count() {
        for (int i = p - 1; i >= 0; --i) cnt[fail[i]] += cnt[i];
        //父亲累加儿子的cnt,因为如果fail[v]=u,则u一定是v的子回文串!
    }
} pt;
struct Suffix_Automaton {
    static const int maxn = 1e6 + 105;
    static const int MAXN = 1e6 + 5;
    //basic
//    map<char,int> nex[maxn * 2];
    int nex[maxn*2][26];
    int link[maxn * 2], len[maxn * 2];
    int last, cnt;
    LL tot_c;//不同串的个数
    //extension
    int cntA[MAXN * 2], A[MAXN * 2];/*辅助拓扑更新*/
    int nums[MAXN * 2];/*每个节点代表的所有串的出现次数*/
    void clear() {
        tot_c = 0;
        last = cnt = 1;
        link[1] = len[1] = 0;
        memset(nex[1], 0, sizeof(nex[1]));
    }
    void init_str(char *s) {
        while (*s) {
            add(*s - 'a');
            ++ s;
        }
    }
    void add(int c) {
        int p = last;
        int np = ++cnt;
//        nex[cnt].clear();
        memset(nex[cnt], 0, sizeof(nex[cnt]));
        len[np] = len[p] + 1;
        last = np;
        while (p && !nex[p][c])nex[p][c] = np, p = link[p];
        if (!p)link[np] = 1, tot_c += len[np] - len[link[np]];
        else {
            int q = nex[p][c];
            if (len[q] == len[p] + 1)link[np] = q, tot_c += len[np] - len[link[np]];
            else {
                int nq = ++cnt;
                len[nq] = len[p] + 1;
//                nex[nq] = nex[q];
                memcpy(nex[nq], nex[q], sizeof(nex[q]));
                link[nq] = link[q];
                link[np] = link[q] = nq;
                tot_c += len[np] - len[link[np]];
                while (nex[p][c] == q)nex[p][c] = nq, p = link[p];
            }
        }
    }
    void build(int n) {
        memset(cntA, 0, sizeof cntA);
        memset(nums, 0, sizeof nums);
        for (int i = 1; i <= cnt; i++)cntA[len[i]]++;
        for (int i = 1; i <= n; i++)cntA[i] += cntA[i - 1];
        for (int i = cnt; i >= 1; i--)A[cntA[len[i]]--] = i;
        /*更行主串节点*/
        int temps = 1;
        for (int i = 0; i < n; i++) {
            nums[temps = nex[temps][s[i] - 'a']] = 1;
        }
        for (int i = cnt, x; i >= 1; i--) {
            x = A[i];
            nums[link[x]] += nums[x];
        }
    }
    void query() {
        int u = 1, LEN = 0;
        for(int i = 0; i < n; ++i) {
            if(nex[u][t[i]-'a']) {
                u = nex[u][t[i]-'a'];
                ++ LEN;
            }else {
                while (u && nex[u][t[i] - 'a'] == 0) u = link[u];
                if (u == 0) u = 1, LEN = 0;
                else {
                    LEN = len[u] + 1;
                    u = nex[u][t[i] - 'a'];
                }
            }
            if(vis[u] == 0) {
                ANS += 1 * (LEN - len[link[u]]);
//                debug(i, t[i], LEN - len[link[u]])
                if (len[link[u]]) lazy[link[u]] = 1;
                vis[u] = LEN;
            }else if(LEN > vis[u]) {
                ANS += 1 * (LEN - vis[u]);
//                debug(i, t[i], LEN - vis[u])
                vis[u] = LEN;
            }
        }
        for(int i = cnt, x; i >= 1; --i) {
            x = A[i];
            if(vis[x] == 0 && len[x] && lazy[x]) {
                ANS += len[x] - len[link[x]];
                vis[x] = len[x];
                if(len[link[x]]) lazy[link[x]] = 1;
            }else if(lazy[x] && vis[x] < len[x]) {
                ANS += len[x] - vis[x];
                vis[x] = len[x];
                if(len[link[x]]) lazy[link[x]] = 1;
            }
            if(len[link[x]]) lazy[link[x]] = 1;
        }
    }
    void DEBUG() {
        for (int i = cnt; i >= 1; i--) {
            printf("nums[%d]=%d numt[%d]=%d len[%d]=%d link[%d]=%d\n", i, nums[i], i, nums[i], i, len[i], i, link[i]);
        }
    }
} sam;

int main() {
#ifndef ONLINE_JUDGE
    freopen("/home/cwolf9/CLionProjects/ccc/in.txt", "r", stdin);
    //freopen("/home/cwolf9/CLionProjects/ccc/out.txt", "w", stdout);
#endif
//    int tim = read();
    scanf("%s", s);
    memcpy(t, s, sizeof(s));
    n = strlen(s);
    reverse(t, t + n);
    sam.clear();
    sam.init_str(s);
    all = sam.tot_c;
    sam.build(n);
    sam.query();
    pt.init();
    for(int i = 0; i < n; ++i) pt.add(s[i], i);
    int hui = pt.p - 2;
    debug(n, hui, all, ANS)
    printf("%lld\n", all - (ANS - hui) / 2);
#ifndef ONLINE_JUDGE
    cout << "time cost:" << clock() << "ms" << endl;
#endif
    return 0;
}

广义后缀自动机

  • 直接离线构建广义后缀自动机(插入函数和普通后缀自动机一模一样),先插入\(s\)串,置\(last=1\),再插入\(rev(s)\),然后对这个后缀自动机求出本质不同的子串个数\(all\)(回文串只计算一次贡献,其他串计算了两次,因为\(x=rev(x)\)),设\(p\)表示\(s\)串本质不同的回文串个数,最后答案即为\(\frac{all+p}2\)

后缀数组


其他:POJ 3415 求两个串长度至少为k的公共子串数量

本题不需要去重。可后缀数组也可后缀自动机写。

后缀自动机
解法和牛客那题基本一样,甚至更简单,因为本题不需要去重,是算总数。
不需要记录每个节点被匹配到的\(lcs\)长度,因此当前节点每次被匹配到的贡献都是\(LEN-max(len[link[u]],k-1)\)
因为是算所有子串的数量,只需要用\(lazy[]\)标记表示这个节点被匹配到的次数,最后逆拓扑序向上传\(lazy[]\)标记即可。

后缀数组
按套路,把\(s,t\)拼成一个串,两遍单调栈,分别算\(t\)串对\(s\)串的贡献和\(s\)串对\(t\)串的贡献

#pragma comment(linker, "/STACK:102400000,102400000")
//#include<bits/stdc++.h>
#include<cstdio>
#include<cstring>
#include<string>
#include<vector>
#include<stack>
#include<map>
#include<iostream>
#include<assert.h>
#define fi first
#define se second
#define endl '\n'
#define o2(x) (x)*(x)
#define BASE_MAX 30
#define mk make_pair
#define eb emplace_back
#define all(x) (x).begin(), (x).end()
#define clr(a, b) memset((a),(b),sizeof((a)))
#define iis std::ios::sync_with_stdio(false); cin.tie(0)
#define my_unique(x) sort(all(x)),x.erase(unique(all(x)),x.end())
using namespace std;
#pragma optimize("-O3")
typedef long long LL;
typedef pair<int, int> pii;
inline LL read() {
    LL x = 0;int f = 0;
    char ch = getchar();
    while (ch < '0' || ch > '9') f |= (ch == '-'), ch = getchar();
    while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
    return x = f ? -x : x;
}
inline void write(LL x) {
    if (x == 0) {putchar('0'), putchar('\n');return;}
    if (x < 0) {putchar('-');x = -x;}
    static char s[23];
    int l = 0;
    while (x != 0)s[l++] = x % 10 + 48, x /= 10;
    while (l)putchar(s[--l]);
    putchar('\n');
}
int lowbit(int x) { return x & (-x); }
template<class T>T big(const T &a1, const T &a2) { return a1 > a2 ? a1 : a2; }
//template<typename T, typename ...R>T big(const T &f, const R &...r) { return big(f, big(r...)); }
//template<class T>T sml(const T &a1, const T &a2) { return a1 < a2 ? a1 : a2; }
//template<typename T, typename ...R>T sml(const T &f, const R &...r) { return sml(f, sml(r...)); }
//void debug_out() { cerr << '\n'; }
//template<typename T, typename ...R>void debug_out(const T &f, const R &...r) {cerr << f << " ";debug_out(r...);}
//#define debug(...) cerr << "[" << #__VA_ARGS__ << "]: ", debug_out(__VA_ARGS__);

#define print(x) write(x);

const LL INFLL = 0x3f3f3f3f3f3f3f3fLL;
const int HMOD[] = {1000000009, 1004535809};
const LL BASE[] = {1572872831, 1971536491};
const int mod = 998244353;
const int MOD = 1e9 + 7;
const int INF = 0x3f3f3f3f;
const int MXN = 2e5 + 7;

int n, m, k;
LL ANS;
char s[MXN], t[MXN];
LL lazy[MXN];
struct Suffix_Automaton {
    static const int maxn = 2e5 + 105;
    static const int MAXN = 2e5 + 5;
    //basic
//    map<char,int> nex[maxn * 2];
    int nex[maxn][58];
    int link[maxn * 2], len[maxn * 2];
    int last, cnt;
    LL tot_c;//不同串的个数
    //extension
    int cntA[MAXN * 2], A[MAXN * 2];/*辅助拓扑更新*/
    int nums[MAXN * 2];/*每个节点代表的所有串的出现次数*/
    void clear() {
        tot_c = 0;
        last = cnt = 1;
        link[1] = len[1] = 0;
//        nex[1].clear();
        memset(nex[1], 0, sizeof(nex[1]));
    }
    void init_str(char *s) {
        while (*s) {
            add(*s - 'A');
            ++ s;
        }
    }
    void add(int c) {
        int p = last;
        int np = ++cnt;
//        nex[cnt].clear();
        memset(nex[cnt], 0, sizeof(nex[cnt]));
        len[np] = len[p] + 1;
        last = np;
        while (p && !nex[p][c])nex[p][c] = np, p = link[p];
        if (!p)link[np] = 1, tot_c += len[np] - len[link[np]];
        else {
            int q = nex[p][c];
            if (len[q] == len[p] + 1)link[np] = q, tot_c += len[np] - len[link[np]];
            else {
                int nq = ++cnt;
                len[nq] = len[p] + 1;
//                nex[nq] = nex[q];
                memcpy(nex[nq], nex[q], sizeof(nex[q]));
                link[nq] = link[q];
                link[np] = link[q] = nq;
                tot_c += len[np] - len[link[np]];
                while (nex[p][c] == q)nex[p][c] = nq, p = link[p];
            }
        }
    }
    void build(int n) {
        for(int i = 0; i <= cnt; ++i) nums[i] = cntA[i] = 0;
        for (int i = 1; i <= cnt; i++) cntA[len[i]]++;
        for (int i = 1; i <= n; i++)cntA[i] += cntA[i - 1];
        for (int i = cnt; i >= 1; i--)A[cntA[len[i]]--] = i;
        /*更行主串节点*/
        int temps = 1;
        for (int i = 0; i < n; i++) {
            nums[temps = nex[temps][s[i] - 'A']] = 1;
        }
        for (int i = cnt, x; i >= 1; i--) {
            x = A[i];
            nums[link[x]] += nums[x];
        }
    }
    void query() {
        int u = 1, LEN = 0;
        for(int i = 0; i < m; ++i) {
            if(nex[u][t[i]-'A']) {
                u = nex[u][t[i]-'A'];
                ++ LEN;
            }else {
                while (u && nex[u][t[i] - 'A'] == 0) u = link[u];
                if (u == 0) u = 1, LEN = 0;
                else {
                    LEN = len[u] + 1;
                    u = nex[u][t[i] - 'A'];
                }
            }
            if(LEN >= k) {
                ANS += (LL)nums[u] * (LEN - big(len[link[u]], k - 1));
                if (len[link[u]]) lazy[link[u]] ++;
            }
        }
        for(int i = cnt, x; i >= 1; --i) {
            x = A[i];
            if(len[x] >= k && lazy[x]) {
                ANS += lazy[x] * nums[x] * (len[x] - big(len[link[x]], k - 1));
                if(len[link[x]]) lazy[link[x]] += lazy[x];
            }
        }
    }
    void DEBUG() {
        for (int i = cnt; i >= 1; i--) {
            printf("nums[%d]=%d numt[%d]=%d len[%d]=%d link[%d]=%d\n", i, nums[i], i, nums[i], i, len[i], i, link[i]);
        }
    }
} sam;

int main() {
#ifndef ONLINE_JUDGE
    freopen("/home/cwolf9/CLionProjects/ccc/in.txt", "r", stdin);
    //freopen("/home/cwolf9/CLionProjects/ccc/out.txt", "w", stdout);
#endif
    while(~scanf("%d", &k) && k) {
        scanf("%s%s", s, t);
        n = strlen(s), m = strlen(t);
        sam.clear();
        sam.init_str(s);
        sam.build(n);
        ANS = 0;
        sam.query();
        for(int i = 0; i <= 2 * n + 5; ++i) lazy[i] = 0;
        printf("%lld\n", ANS);
    }
#ifndef ONLINE_JUDGE
    cout << "time cost:" << clock() << "ms" << endl;
#endif
    return 0;
}
posted @ 2019-07-27 22:09 Cwolf9 阅读(...) 评论(...) 编辑 收藏

Contact with me