poj- 3415 Common Substrings 后缀数组+单调栈

题目链接:Common Substrings

题意:给两个串s1,s2,求出长度不小于k的公共子串个数

题解:我们先想一个暴力的解法,先把两个串连到一起中间加一个特殊字符。然后求出sa,和lcp,然后n^2枚举两个子串的开始位置,然后对于每两个子串的公共前缀长度L对

答案的贡献是L-k+1;求和就是答案。但是这个是n^2*log.肯定不行

先把答案分成两部分求,第一部分是对于s1的每一个后缀计算字典序比它小的每一个s2的后缀对答案的贡献,第二部分是对于s2的每一个后缀计算字典序比它小的每一个s1的后缀对答案的贡献.这样不会重复也不会遗漏。前面一样连起来求出sa,和lcp。对于然后我们就需要用单调栈。从底到顶单调递增。为什么要用单调栈呢,因为有这样一个性质

那么对于j和k,不妨设rank[j]<rank[k],则有以下性质:
suffix(j)和suffix(k)的最长公共前缀为height[rank[j]+1],height[rank[j]+2],height[rank[j]+3],…,height[rank[k]]中的最小值。

,这样就可以用一个单调栈来保存一个串之前的所有串的与本串的lcp了。当如果有height值小于了当前栈顶的height值,那么大于它的那些只能按照当前这个小的值来计算

要用两个栈一个维护数量和一个维护的height[]值。计算两次就行了。

//#include<bits/stdc++.h>
#include<iostream>
#include<string>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define pb push_back
#define ll long long
#define PI 3.14159265
#define ls l,m,rt<<1
#define rs m+1,r,rt<<1|1
#define ws wppp
#define eps 1e-7
using namespace std;
const int N=2e5+5;
const int mod=1e9+7;
int s[N];
int sa[N], t[N], t2[N], c[N], n;
int ran[N], lcp[N];
const int inf=0x3fffffff;
void get_sa(int m)
{
    int i, *x = t, *y = t2;
    for(i = 0; i < m; i++) c[i] = 0;
    for(i = 0; i < n; i++) c[x[i] = s[i]]++;
    for(i = 1; i < m; i++) c[i] += c[i-1];
    for(i = n-1; i >= 0; i--) sa[--c[x[i]]] = i;
    for(int k = 1; k <= n; k <<= 1)
    {
        int p = 0;
        for(i = n-k; i < n; i++) y[p++] = i;
        for(i = 0; i < n; i++) if(sa[i] >= k) y[p++] = sa[i] - k;
        for(i = 0; i < m; i++) c[i] = 0;
        for(i = 0; i < n; i++) c[x[y[i]]]++;
        for(i = 0; i< m; i++) c[i] += c[i-1];
        for(i = n-1; i >= 0; i--) sa[--c[x[y[i]]]] = y[i];
        swap(x, y);
        p = 1; x[sa[0]] = 0;
        for(i = 1; i < n; i++)
            x[sa[i]] = y[sa[i-1]] == y[sa[i]] && y[sa[i-1]+k] == y[sa[i]+k] ? p-1 : p++;
        if(p >= n) break;
        m = p;
    }
    int k = 0;
    for(i = 0; i < n; i++) ran[sa[i]] = i;
    for(i = 0; i < n; i++)
    {
        if(k) k--;
        int j = sa[ran[i]-1];
        while(i+k<n&&j+k<n&&s[i+k] == s[j+k]) k++;
         lcp[ran[i]] = k;
    }
}
int m=0;
ll st[N][2];
char s1[N],s2[N];
int main()
{
    while(scanf("%d",&m)&&m)
    {
        scanf("%s",s1);
        scanf("%s",s2);
        int l1=strlen(s1);
        int l2=strlen(s2);
        for(int i=0;i<l1;i++)
        {
            s[i]=s1[i]+1;
        }
        s[l1]=1;
        for(int i=0;i<l2;i++)
        {
            s[l1+i+1]=s2[i]+1;
        }
        n=l1+l2+1;
        s[n]=0;
        n++;
        get_sa(300);
      //  cout<<n<<endl;
        ll ans=0,sum=0;
        int tp=0;
        for(int i=1;i<n;i++)
        {
            if(lcp[i]<m)tp=0,sum=0;
            else
            {
                int num=0;
                while(tp&&lcp[i]<st[tp-1][0])
                {
                    sum+=(lcp[i]-st[tp-1][0])*(st[tp-1][1]);
                    num+=st[tp-1][1];
                    tp--;
                }
                st[tp][0]=lcp[i];
                if(sa[i-1]>l1)
                {
                    sum+=lcp[i]-m+1;
                    st[tp++][1]=num+1;
                }
                else st[tp++][1]=num;
                if(sa[i]<l1)
                {
                    ans+=sum;
                }
            }
        }
        for(int i=1;i<n;i++)
        {
            if(lcp[i]<m)tp=0,sum=0;
            else
            {
                int num=0;
                while(tp&&lcp[i]<st[tp-1][0])
                {
                    sum+=(lcp[i]-st[tp-1][0])*(st[tp-1][1]);
                    num+=st[tp-1][1];
                    tp--;
                }
                st[tp][0]=lcp[i];
                if(sa[i-1]<l1)
                {
                    sum+=lcp[i]-m+1;
                    st[tp++][1]=num+1;
                }
                else st[tp++][1]=num;
                if(sa[i]>l1)
                {
                    ans+=sum;
                }

            }
        }
        printf("%lld\n",ans);
    }
}

 

posted @ 2019-01-28 16:07  lhclqslove  阅读(187)  评论(0)    收藏  举报