牛客多校第六场F Palindrome Mouse(回文树求某回文串的所有回文子串)题解

题意:

传送门
给一个串\(str\),现得到他的所有回文子串的集合\(set\),问有多少对\((s,t)\)满足\(s,t\in set\)并且\(s\)\(t\)的一个子串。

思路:

题意就是要求每个本质不同回文串的子串个数的和。建立回文树,那么对于每一个回文树上的节点来说,他的所有回文子串就是他\(fail\)链上所有子节点的\(fail\)和他自己的\(fail\)加上\(next\)链上所有的父节点,当然需要除重,可以用\(vis\)标记一下。

代码:

#include<map>
#include<set>
#include<cmath>
#include<cstdio>
#include<stack>
#include<ctime>
#include<vector>
#include<queue>
#include<cstring>
#include<string>
#include<sstream>
#include<iostream>
#include<algorithm>
typedef long long ll;
using namespace std;
const int maxn = 100000 + 5;
const ll MOD = 998244353;
const int INF = 0x3f3f3f3f;
ll ret, ans;
struct PAM{
    int nex[maxn][26];
    int fail[maxn];
    int len[maxn];
    int str[maxn];
    int cnt[maxn];
    int vis[maxn];
    int last;
    int tot;
    int N;

    int newnode(int L){
        for(int i = 0; i < 26; i++) nex[tot][i] = 0;
        len[tot] = L;
        cnt[tot] = 0;
        vis[tot] = 0;
        return tot++;
    }

    void init(){
        tot = 0;
        newnode(0);
        newnode(-1);
        last = 0;
        N = 0;
        str[0] = -1;
        fail[0] = 1;
    }

    int getfail(int x){
        while(N - len[x] - 1 < 0 || str[N - len[x] - 1] != str[N]) x = fail[x];
        return x;
    }

    void add(char ss){
        int c = ss - 'a';
        str[++N] = c;
        int cur = getfail(last);
        if(!nex[cur][c]){
            int now = newnode(len[cur] + 2);
            fail[now] = nex[getfail(fail[cur])][c];
            nex[cur][c] = now;
        }
        last = nex[cur][c];
        cnt[last]++;
    }

    void dfs(int u){
        stack<int> in;
        while(!in.empty()) in.pop();
        int t = fail[u];
        while(t > 1 && !vis[t]){	//fail链上还没有算进去的子串
            ret++;
            in.push(t);
            vis[t] = 1;
            t = fail[t];
        }
        if(u > 1){
            ans += ret;
            vis[u] = 1;
            ret++;
        }
        for(int i = 0; i < 26; i++)
            if(nex[u][i])
                dfs(nex[u][i]);
        if(u > 1){
            ret--;
            vis[u] = 0;
        }
        t = u;
        while(!in.empty()){
            ret--;
            vis[in.top()] = 0;
            in.pop();
        }
    }

}pam;
char s[maxn];
int main(){
    int T, ca = 1;
    scanf("%d", &T);
    while(T--){
        scanf("%s", s);
        int len = strlen(s);
        pam.init();
        for(int i = 0; i < len; i++) pam.add(s[i]);
        ret = ans = 0;
        pam.dfs(0);
        pam.dfs(1);
        printf("Case #%d: %lld\n", ca++, ans);
    }
    return 0;
}


posted @ 2019-08-07 11:16  KirinSB  阅读(176)  评论(0)    收藏  举报