CodeForces - 7D Palindrome Degree

最近接触了一点字符串算法,其实也就是一个简单的最大回文串算法,给定字符串s,求出最大字符串长度。

算法是这样的, 用'#'将s字符串中的每个字符分隔,比如s = “aba”,分割后变成#a#b#a#,然后利用下面的算法:

pre:

mx ←0

   for i: = 1 to n-1

        if(mx>i)

          p[i] = min(p[2*id-i], mx-i)

        else p[i] = 1

       while(str[i+p[i]] == str[i-p[i]])

            p[i]++

      if(i+p[i]>mx)

        mx = i+p[i]

         id = i

注意在将s添加'#'之后为了防止越界访问,需要再整个字符串前面加上’$’这样i就是从1开始,p[i]表示在字符中以i为中心的回文串的右半长度,准确的说是r-i+1,r为回文串最右边的字符的下标,

mx表示i之前的位置j的回文串最大右端值,然后每次循环结束的时候更新mx并用id记录i值。

本题求的是字符串s的所有的前缀字符串的k值,k值是这样的定义的,也就是对一个回文串进行二分,分得的两部分仍然是回文串就将s字符串的k值增加1,然后继续分,直到不是回文串。

由于字符串k值取决于自身是不是回文串,所以要先进行判断,然后f[str] = f[substr] + 1,substr表示为str的前半部分的k值,由于substr应该在前面求出了,所以整个过程可以是一个dp过程。

代码:

#include <iostream>
#include <sstream>
#include <cstdio>
#include <climits>
#include <cstring>
#include <cstdlib>
#include <string>
#include <stack>
#include <map>
#include <cmath>
#include <vector>
#include <queue>
#include <algorithm>
#define esp 1e-6
#define pi acos(-1.0)
#define pb push_back
#define lson l, m, rt<<1
#define rson m+1, r, rt<<1|1
#define mp(a, b) make_pair((a), (b))
#define in  freopen("in.txt", "r", stdin);
#define out freopen("out.txt", "w", stdout);
#define print(a) printf("%d\n",(a));
#define bug puts("********))))))");
#define stop  system("pause");
#define Rep(i, c) for(__typeof(c.end()) i = c.begin(); i != c.end(); i++)
#define inf 0x0f0f0f0f

using namespace std;
typedef long long  LL;
typedef vector<int> VI;
typedef pair<int, int> pii;
typedef vector<pii> VII;
typedef vector<pii, int> VIII;
typedef VI:: iterator IT;
const int maxn = 5*1000000+100;
char str[maxn<<1], s[maxn];
int p[maxn<<1];
int ans;
int n;
int f[maxn<<1];
void Init(void)
{
    str[0] = '$', str[1] = '#';
    for(int i = 0; i < n; i++)
    {
        str[i*2+2] = s[i];
        str[i*2+3] = '#';
    }
    int nn = 2*n+2;
    str[nn] = 0;
    int mx = 0, id;
    for(int i = 1; i < nn; i++)
    {
        if(mx > i)
        {
            p[i] = min(p[2*id-i], mx-i);
        }
        else p[i] = 1;
        while(str[i+p[i]] == str[i-p[i]])
            p[i]++;
        if(i + p[i] > mx)
            mx = i+p[i],
            id = i;
    }
}
void solve(void)
{
    LL ans = 0;
    for(int i = 1; i <= n; i++)
    {
        int l = 2, r = 2*i;
        int m = (l+r)>>1;
        if(p[m]*2-1 >= r-l+1)
            f[r] = f[m-1+((m%2) ? 0: -1)]+1;
        ans += f[r];
    }
    printf("%I64d\n", ans);
}

int main(void)
{
    scanf("%s", s);
    n = strlen(s);
    Init();
    solve();
    return 0;
}

更详细的介绍在这里:http://www.cnblogs.com/wuyiqi/archive/2012/06/25/2561063.html

posted on 2013-10-20 11:01  rootial  阅读(809)  评论(1编辑  收藏  举报

导航