poj 3415 后缀数组分组+排序+并查集

Source Code

Problem: 3415   User: wangyucheng
Memory: 16492K   Time: 704MS
Language: C++   Result: Accepted
    • Source Code
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
#define N 510000
typedef long long ll;
int wa[N],wb[N],sa[N],wv[N],ss[N],a[N];
int n;
int a1,a2;
int cmp(int *r,int x,int y,int k){
   return r[x]==r[y]&&r[x+k]==r[y+k];	
}
structP{
   int x,y,z;
   P(int a=0,int b=0){
	   x=a,y=b;	
	}	
	bool operator<(P a)const{
	   return x>a.x;
	}
}b[N];
int b1;
void da(int *r,int m){
    int p,i,j,*x=wa,*y=wb;
	for(i=0;i<n;i++)r[i]++;
	r[n++]=0;
	for(i=0;i<m;i++)ss[i]=0;
	for(i=0;i<n;i++)ss[x[i]=r[i]]++;
	for(i=1;i<m;i++)ss[i]+=ss[i-1];
	for(i=n-1;i>=0;i--)sa[--ss[x[i]]]=i;
	for(p=0,j=1;p<n;j<<=1,m=p){
	   for(p=0,i=n-j;i<n;i++)y[p++]=i;
	   for(i=0;i<n;i++)if(sa[i]>=j)y[p++]=sa[i]-j;
	   for(i=0;i<m;i++)ss[i]=0;
	   for(i=0;i<n;i++)wv[i]=x[y[i]];
	   for(i=0;i<n;i++)ss[wv[i]]++;
	   for(i=1;i<m;i++)ss[i]+=ss[i-1];
	   for(i=n-1;i>=0;i--)sa[--ss[wv[i]]]=y[i];
	   	for(swap(x,y),x[sa[0]]=0,p=1,i=1;i<n;i++)
		x[sa[i]]=cmp(y,sa[i],sa[i-1],j)?p-1:p++;
	}	
}
int rank[N],he[N];
void ma(int *r){
	int i,k=0;
	for(i=0;i<n;i++)rank[sa[i]]=i;
	for(i=0;i<n-1;i++){
	 	for(k?k--:0;r[sa[rank[i]-1]+k]==r[i+k];k++);
		he[rank[i]]=k;
	}
}
char in[N];
int K;
int f[N];
int get(int x){
	return f[x]==x?x:f[x]=get(f[x]);
}
ll ans;
ll s[N][2];
void he1(int x,int y,ll &z){
	int c=get(x);
	int d=get(y);
	z+=s[c][0]*s[d][1]+s[d][0]*s[c][1];
	f[c]=d;
	s[d][0]+=s[c][0];
	s[d][1]+=s[c][1];
}
void solv(int l,int r){
	 int i,j,y;
	if(a[sa[l]]-1=='#')return;
	b1=0;
	for(i=l;i<=r;i++){
		s[i][0]=s[i][1]=0;
	    if(sa[i]<a1)y=1;
		else y=0;
		s[i][y]++;
		if(i==l)b[++b1]=P(n+1,i);
		else b[++b1]=P(he[i]-K+1,i);
	}
	sort(b+1,b+b1+1);
	for(i=l;i<=r;i++)f[i]=i;
	int la;
	ll tot=0;
	la=0;
	b[b1+1].x=0;
	for(i=2;i<=b1+1;i++){
	    if(i==b1+1||b[i].x!=b[i-1].x){
		   for(j=la+1;j<i;j++){
			   if(b[j].y>l)he1(b[j].y-1,b[j].y,tot);
			}
		   la=i-1;
		   ans+=tot*(ll)(b[i-1].x-b[i].x);
		}
	}
}

int main(){
	while(scanf("%d",&K),K){
		ans=0;
	   scanf("%s",in);
	   a1=strlen(in);
	   int i,j;
	   for(i=0;i<a1;i++)a[i]=in[i];
	   a[a1]='#';	
	   scanf("%s",in);
	   a2=strlen(in);
	   n=a1+a2+1;
	   for(i=a1+1;i<n;i++)a[i]=in[i-a1-1];
	   da(a,300);
	   ma(a);
		int la=0;
		for(i=1;i<=n;i++){
		    if(he[i]<K||i==n){
			    solv(la,i-1);
				la=i;	
			}	
		}
		printf("%lld\n",ans);
	}
	
	
}
posted @ 2014-06-13 18:05  wangyucheng  阅读(267)  评论(0编辑  收藏  举报