BZOJ 3238 [AHOI2013]差异 (后缀数组+单调栈)

题目大意:求$\sum_{1\leq i<j \leq N} suf_{i}+suf_{j}-2\cdot lcp(suf_{i},suf_{j})$

先是后缀数组打错了,又是把+=打成了=,我是zz

我的做法比较奇葩..

转化式子,原式=$\sum_{i=1}^{n-1}(i+1)\cdot i-\sum_{1\leq i<j \leq N}2\cdot lcp(suf_{i},suf_{j})$

这样计算后面的部分就行了

首先用$sa$预处理出$height$数组

对后缀进行排序后,对于某个一个后缀$suf_{i}$,如果另一个后缀$suf_{j}$和它的$lcp$长度是$x$,必须要保证$\forall \;k\in[i+1,j-1],h_{k}\geq x$

用一个单调栈维护$height$,设$num_{tp}$表示栈中$lcp$长度为$L_{tp}$的后缀数量总和

用一个动态的数$sum$记录当前栈中的$h_{k}*num_{k}$总和

每遍历到一个排名$i$,因为 排名$<i$的后缀 和 排名$>i$的后缀 的$lcp$最长是$h_{i}$

故先删去栈中大于$h_{i}$的元素,并记录一共删掉了多少元素

在$i$之后,要去掉对于排名在i之后的后缀的无效长度,所有$h$大于$h_{i}$的部分都要修改成$h_{i}$,即$sum-=(L_{tp}-h_{i})\cdot num_{tp}$

再把$suf_{i}$推入栈中,接着把删掉的元素数量加回来

最终答案就是每次统计完成后的$sum$总和

 1 #include <bitset>
 2 #include <cstdio>
 3 #include <cstring>
 4 #include <algorithm>
 5 #define N1 505000
 6 #define ll long long 
 7 #define inf 0x3f3f3f3f
 8 #define rint register int 
 9 using namespace std;
10 
11 
12 int len;
13 int gch(char *str)
14 {
15     char c=getchar();
16     while(c<'a'||c>'z'){c=getchar();}
17     while(c>='a'&&c<='z'){str[++len]=c;c=getchar();}
18 }
19 int gint()
20 {
21     int ret=0,fh=1;char c=getchar();
22     while(c<'0'||c>'9'){if(c=='-')fh=-1;c=getchar();}
23     while(c>='0'&&c<='9'){ret=ret*10+c-'0';c=getchar();}
24     return ret*fh;
25 }
26 char str[N1];
27 int rk[N1],tr[N1],sa[N1],hs[N1],h[N1];
28 int check(int i,int j,int k){
29     if(i+k>len||j+k>len) return 0;
30     return (rk[i]==rk[j]&&rk[i+k]==rk[j+k])?1:0;}
31 void get_sa()
32 {
33     rint i,cnt=0;
34     for(i=1;i<=len;i++) hs[str[i]]++;
35     for(i=1;i<=128;i++) if(hs[i]) tr[i]=++cnt;
36     for(i=1;i<=128;i++) hs[i]+=hs[i-1];
37     for(i=1;i<=len;i++) rk[i]=tr[str[i]],sa[hs[str[i]]--]=i;
38     for(int k=1;cnt<len;k<<=1)
39     {
40         for(i=1;i<=cnt;i++) hs[i]=0;
41         for(i=1;i<=len;i++) hs[rk[i]]++;
42         for(i=1;i<=cnt;i++) hs[i]+=hs[i-1];
43         for(i=len;i>=1;i--) if(sa[i]>k) tr[sa[i]-k]=hs[rk[sa[i]-k]]--;
44         for(i=1;i<=k;i++) tr[len-i+1]=hs[rk[len-i+1]]--;
45         for(i=1;i<=len;i++) sa[tr[i]]=i;
46         for(i=1,cnt=0;i<=len;i++) tr[sa[i]]=check(sa[i],sa[i-1],k)?cnt:++cnt;
47         for(i=1;i<=len;i++) rk[i]=tr[i];
48     }
49     for(i=1;i<=len;i++){
50         if(rk[i]==1) continue;
51         for(int j=max(1,h[rk[i-1]]-1);;j++)
52             if(str[i+j-1]==str[sa[rk[i]-1]+j-1]) h[rk[i]]=j;
53             else break;
54     }
55 }
56 int stk[N1],num[N1],L[N1],tp;
57 ll sum;
58 ll solve()
59 {
60     ll ans=0,tmp;tp=0;
61     for(int i=2;i<=len;i++)
62     {
63         tmp=0;
64         while(tp>0&&L[tp]>h[i]){
65             tmp+=num[tp];
66             sum-=1ll*(L[tp]-h[i])*num[tp];
67             L[tp]=0,num[tp]=0,tp--;
68         }
69         if(h[i]>L[tp]) 
70             tp++,L[tp]=h[i];
71         num[tp]+=tmp+1;
72         sum+=h[i],ans+=sum;
73     }
74     return ans;
75 }
76 
77 int main()
78 {
79     gch(str);
80     get_sa();
81     ll ans=0;
82     for(int i=1;i<=len-1;i++)
83         ans+=1ll*(i+1)*i;
84     ans=ans/2*3;
85     printf("%lld\n",ans-2ll*solve());
86     return 0;
87 }

 

posted @ 2018-12-08 20:42  guapisolo  阅读(178)  评论(0编辑  收藏  举报