[bzoj3238] [Ahoi2013] 差异

Description

给定一个长度为 \(n\) 的字符串 \(S\),令 \(Ti\) 表示它从第 \(i\) 个字符开始的后缀。求
\(\sum\limits_{1 \leq i <j \leq n} len(Ti)+len(Tj)-2 \times lcp(Ti,Tj)\)
其中,\(len(a)\) 表示字符串 \(a\) 的长度,\(lcp(a,b)\) 表示字符串 \(a\) 和字符串 \(b\) 的最长公共前缀。

Input

一行,一个字符串 \(S\)

Output

一行,一个整数,表示所求值。

Sample Input

cacao

Sample Output

54

HINT

对于 \(100%\) 的数据,保证 \(2 \leq n \leq 500000\),且均为小写字母。


想法

化简一下要求的那个式子

\[\begin{equation*} \begin{aligned} &\sum\limits_{1 \leq i <j \leq n} len(Ti)+len(Tj)-2 \times lcp(Ti,Tj) \\ =&\frac{n(n+1)(n-1)}{2} -2 \times \sum\limits_{1 \leq i <j \leq n} lcp(Ti,Tj) \end{aligned} \end{equation*} \]

不妨设 \(x= \sum\limits_{1 \leq i <j \leq n} lcp(Ti,Tj)\) ,我们要求的就是它
建出后缀自动机,对于其中每个节点,预处理出它可以到的结束节点的个数 \(size\) ,这个点对 \(x\) 的贡献就是 \(\frac{size(size-1)}{2}\)


代码

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>

using namespace std;

const int N = 500005;
typedef long long ll;

struct node{
    int len,size;
    node *ch[26],*pa;
}pool[N*2],*root,*last;
int cnt;
void insert(int c){
    node *p=last,*cur=&pool[++cnt];
    cur->len=p->len+1;
    for(;p && !p->ch[c];p=p->pa) p->ch[c]=cur;
    if(!p) cur->pa=root;
    else{
        node *q=p->ch[c],*nq;
        if(q->len==p->len+1) cur->pa=q;
        else{
            nq=&pool[++cnt];
            nq->len=p->len+1;
            for(int i=0;i<26;i++) nq->ch[i]=q->ch[i];
            nq->pa=q->pa;
            q->pa=cur->pa=nq;
            for(;p && p->ch[c]==q;p=p->pa) p->ch[c]=nq;
        }
    }
    last=cur;
}

char s[N];
int n;

ll ans,del[N*2];
int vis[N*2];
void Get_size(node *p){
    if(vis[p-pool]) {
        ans-=del[p-pool];
        return;
    }
    ll pre=ans;
    for(int i=0;i<26;i++){
        if(!p->ch[i]) continue;
        Get_size(p->ch[i]);
        p->size+=p->ch[i]->size;
    }
    if(p->size>=2) ans-=(ll)p->size*(p->size-1);
    del[p-pool]=pre-ans; vis[p-pool]=1;
}

int main()
{
    scanf("%s",s);
    n=strlen(s);
    
    root=&pool[++cnt];
    last=root;
    for(int i=0;i<n;i++) insert(s[i]-'a');
    
    node *tmp=last;
    for(;tmp!=root;tmp=tmp->pa) tmp->size=1;
    ans=(ll)n*(n+1)/2*(n-1);
    Get_size(root);
    
    printf("%lld\n",ans+(ll)n*(n-1));
    
    return 0;
}
posted @ 2019-04-13 20:13  秋千旁的蜂蝶~  阅读(98)  评论(0编辑  收藏  举报