算法学习:回文自动机

【定义】

【自动机】

参照AC自动机

 


【前置知识】

【AC自动机】

【manacher】

其实不学这两个也可以,但是学过之后会更方便理解

 


【解决问题】

主要解决回文串的问题

能求出   字符串中回文子串的长度和出现次数

 


 【算法思想】

AC自动机是建立了一个字符串字串的自动机,保存了所有子串之间的信息

信息包括,这个子串作为模式串出现的次数,这个子串的后缀之中能够包含多少模式串

 

而回文自动机则建立了一个字符串中所有回文子串的自动机,保存了所有回文子串之间的信息

信息包括,这个回文子串的长度和出现次数,这个回文子串的后缀之中包含的其他子串

这里的回文子串和位置无关只和构成回文子串的字符有关

同AC自动机一样,需要找到后缀中的信息,自然就需要 f a i l 指针

而这里求取 f a i l 指针时,就不单单是直接插入,而是要在 f a i l 树上不停的跳指针

找到能够增加的符合要求的回文子串的位置,然后扩展,

同时他的fail的建立也需要遵守这个规定

 

谈谈初始化和其他元素

回文自动机的初始化比较重要,因为会对后续产生比较大的影响

(我因为这个调了一天)

 

 

 

 

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<string>
#define ll long long
using namespace std;
const int mod = 19930726;
const int MAXN = 1000010;
ll k;
int n, cnt;
int now;
struct note
{
    int son[26];
    ll len;
    int  siz;
}trie[MAXN];
int fail[MAXN], pos;
bool operator<(note a, note b)
{
    return a.len > b.len;
}
char s[MAXN];
int get(int x)
{
    while (s[pos] != s[pos - trie[x].len - 1]) x = fail[x];
    return x;
}
void insert(int x)
{
    int cur = get(now);
    if (!trie[cur].son[x])
    {
        int u = ++cnt;
        trie[u].len = trie[cur].len + 2;
        fail[u] = trie[get(fail[cur])].son[x];
        trie[cur].son[x] = u;
    }
    trie[trie[cur].son[x]].siz++;
    now = trie[cur].son[x];
}
ll poww(ll a, int b)
{
    ll res = 1;
    while (b)
    {
        if (b & 1)
        {
            res = (res*a) % mod;
        }
        a = (a*a) % mod;
        b = b >> 1;
    }
    return res;
}
void init()
{
    fail[0] = 1; fail[1] = 0;
    now = 0, cnt = 1;
    trie[1].len = -1;
}
int main()
{
    scanf("%d%lld", &n, &k);
    scanf("%s", s+1);
    init();
    s[0] = 0;
    for (pos = 1; pos <= n; pos++)
        insert(s[pos] - 'a');

    for (int i = cnt; i >= 2; i--)
        trie[fail[i]].siz = (trie[fail[i]].siz + trie[i].siz) % mod;
    //for (int i = 1; i <= cnt; i++)
        //printf("%lld %d\n", trie[i].len, trie[i].siz);
    sort(trie + 1, trie + 1 + cnt);
    ll ans = 1;
    int pos = 1;
    while (k)
    {
        if (pos > cnt)
        {
            printf("-1");
            return 0;
        }
        if (trie[pos].len % 2 == 0)
        {
            pos++;
            continue;
        }
        ans = (ans*poww(trie[pos].len, k < trie[pos].siz ? k : trie[pos].siz)) % mod;
        k -= k < trie[pos].siz ? k : trie[pos].siz;
        pos++;
    }
    printf("%lld", ans);
    return 0;
}
View Code

 

posted @ 2019-08-09 17:01  rentu  阅读(149)  评论(0编辑  收藏  举报