2019牛客全国多校训练四 I题 string (SAM+PAM)

链接:https://ac.nowcoder.com/acm/contest/884/I
来源:牛客网

题目描述

We call $a, b$ non-equivalent if and only if $a \neq b$ and $a \neq rev(b)$, where $rev(s)$ refers to the string obtained by reversing the characters of $s$; for example, $rev(abca) = acba$.
There is a string $s$ consisting of lowercase letters. You need to choose some substrings of $s$ so that any two of them are non-equivalent. Find the largest number of substrings you can choose.

输入描述:

A line containing a string $s$ of lowercase letters.

输出描述:

A positive integer: the largest possible number of pairwise non-equivalent substrings of $s$.
示例1

输入

abac

输出

8

说明

The following set of substrings is such a choice: abac, b, a, ab, aba, bac, ac, c.

备注:

$1 \leq |s| \leq 2 \times 10^5$, and $s$ consists of lowercase letters.

题解:

题目给出一个字符串 s,要求从 s 的子串中选出一个尽量大的集合,使得集合中任意两个子串既不相等也不互为反转;也就是说,对任一非回文子串 str,str 与 rev(str)(rev(str) 表示把 str 反转得到的串)最多只能选一个。求这个集合的最大大小。

思路:先用 SAM 统计出串 s#rev(s) 中不包含 '#' 的本质不同子串数量 ans1(即 s 与 rev(s) 的本质不同子串的并集大小);再用 PAM 统计出 s 中本质不同的回文子串数量 ans2;

则答案就是 (ans1+ans2)/2。

为什么呢?

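因为在用 SAM 统计 s#rev(s) 的时候,每个非回文子串 t 会和它的反转 rev(t) 一起被统计两遍,而本身就是回文串的子串只会被统计一遍;再加上 PAM 统计的回文子串数量后,每一类(t 与 rev(t))恰好被统计两遍,除以 2 即为答案。实现时先用 SAM 统计 s#rev(s) 的全部本质不同子串,再减去包含 '#' 的子串:由于 '#' 只出现一次,这类子串由起点(|s|+1 种)和终点(|s|+1 种)唯一确定,共 (|s|+1)^2 个。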

参考代码:

  1 #include<bits/stdc++.h>
  2 using namespace std;
  // Template shorthands. Only `ll` is used by this program; pii, pil, INF and inf
// are leftovers that nothing in the listing references.
typedef long long ll;
typedef std::pair<int, int> pii;
typedef std::pair<int, long long> pil;
const int INF = 0x3f3f3f3f;            // 1061109567: every byte is 0x3f (memset-friendly)
const ll inf = 0x3f3f3f3f3f3f3f3fLL;   // the same byte pattern in 64 bits
  // Reads the next integer from stdin, skipping every non-digit character before it.
// A '-' seen anywhere in that skipped prefix makes the result negative (the sign is
// set, not toggled). Digits are accumulated until the first non-digit, which is
// consumed. Returns 0 if input ends before any digit appears.
// `ch` is an int so that EOF is distinguishable from every byte value; as a char,
// EOF never looked like a digit and the skip loop spun forever at end of input.
// main() does not call this; it reads its input with scanf.
inline int read()
{
    int x = 0, f = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
 15 const int maxn=4e5+10;
 16 const int MAXN=4e5+10; 
 17 char str[maxn];
 18 int s[maxn];
 19 ll ans;
 20 struct SAM{
 21     int l[maxn<<1],fa[maxn<<1],nxt[maxn<<1][30];
 22     int last,cnt;
 23     
 24     void Init()
 25     {
 26         ans=0;last=cnt=1;
 27         l[cnt]=fa[cnt]=0;
 28         memset(nxt[cnt],0,sizeof(nxt[cnt]));        
 29     }
 30     
 31     int NewNode()
 32     {
 33         ++cnt;
 34         memset(nxt[cnt],0,sizeof(nxt[cnt]));
 35         l[cnt]=fa[cnt]=0;
 36         return cnt;    
 37     }
 38     
 39     void Insert(int ch)
 40     {
 41         int np=NewNode(),p=last;
 42         last=np; l[np]=l[p]+1;
 43         while(p&&!nxt[p][ch]) nxt[p][ch]=np,p=fa[p];
 44         if(!p) fa[np]=1;
 45         else
 46         {
 47             int q=nxt[p][ch];
 48             if(l[p]+1==l[q]) fa[np]=q;
 49             else
 50             {
 51                 int nq=NewNode();
 52                 memcpy(nxt[nq],nxt[q],sizeof(nxt[q]));
 53                 fa[nq]=fa[q];
 54                 l[nq]=l[p]+1;
 55                 fa[np]=fa[q]=nq;
 56                 while(nxt[p][ch]==q) nxt[p][ch]=nq,p=fa[p];
 57             }    
 58         }    
 59         ans+=1ll*(l[last]-l[fa[last]]);    
 60     }
 61     
 62 }sam;
 63 
 64 struct Palindromic_Tree{
 65     int next[MAXN][26];
 66     int fail[MAXN];
 67     int cnt[MAXN];
 68     int num[MAXN];
 69     int len[MAXN];
 70     int S[MAXN];
 71     int last;
 72     int n;
 73     int p;
 74  
 75     int newnode(int l) 
 76     {
 77         for(int i=0;i<26;++i) next[p][i]=0;
 78         cnt[p]=0;
 79         num[p]=0;
 80         len[p]=l;
 81         return p++;
 82     }
 83  
 84     void Init() 
 85     {
 86         p=0;
 87         newnode( 0);
 88         newnode(-1);
 89         last=0;
 90         n=0;
 91         S[n]=-1;
 92         fail[0]=1;
 93     }
 94  
 95     int get_fail(int x)
 96     {
 97         while(S[n-len[x]-1]!=S[n])x=fail[x] ;
 98         return x ;
 99     }
100  
101     void add(int c) 
102     {
103         S[++ n]=c;
104         int cur=get_fail(last) ;
105         if(!next[cur][c]) 
106         {
107             int now=newnode(len[cur]+2) ;
108             fail[now]=next[get_fail(fail[cur])][c] ;
109             next[cur][c]=now ;
110             num[now]=num[fail[now]]+1;
111         }
112         last=next[cur][c];
113         cnt[last]++;
114     }
115  
116     ll count() 
117     {
118         ll res=p*1ll;
119         for(int i=p-1;i>=0;--i) cnt[fail[i]]+=cnt[i];
120         //for(int i=1;i<=p;++i) res+=cnt[i];
121         //cout<<"res "<<res<<endl;
122         return (res-2);
123     }
124 } pam;
125 
126 int main()
127 {
128     scanf("%s",str);
129     int len=strlen(str);
130     
131     sam.Init();
132     for(int i=0;i<len;++i) sam.Insert(str[i]-'a');
133     sam.Insert(28);
134     for(int i=len-1;i>=0;--i) sam.Insert(str[i]-'a');
135     ans-=1ll*(len+1)*(len+1);
136     //cout<<"ans "<<ans<<endl;
137     pam.Init();
138     for(int i=0;i<len;++i) pam.add(str[i]-'a');
139     ans=ans+pam.count();
140     
141     printf("%lld\n",(ans/2ll));
142     
143     
144     return 0;    
145 }
View Code

 

 

 
posted @ 2019-07-28 09:36  StarHai  阅读(415)  评论(0)  编辑  收藏  举报