# bzoj 4310 跳蚤 二分答案+后缀数组/后缀树

### 做法

bool cmp(int x,int y,int l1,int l2){//s[x..x+l1-1],s[y..y+l2-1]
int tp=lcp(x,y);
if(tp<l1&&tp<l2) return s[x+tp]>s[y+tp];//在比较范围直接比较
return l1>l2; //否则直接比较长度
}

### solution

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdlib>
using namespace std;
typedef long long LL;
const int M=200007;

int n,m,st,len;
char s[M];
int id[M];
int last,tot;
int ch[M][26];
int fa[M],stp[M];
int ed[M];
int dfn[M],pid[M],tdfn;
int pre[M][20],dep[M],Mx;
LL sum[M];

struct edge{int y,nxt;};
struct vec{
int g[M],te;
edge e[M];
vec(){memset(g,0,sizeof(g)); te=0;}
void clear(){memset(g,0,sizeof(g)); te=0;}
inline void push(int x,int y){e[++te].y=y;e[te].nxt=g[x];g[x]=te;}
inline int& operator () (int &x) {return g[x];}
inline edge& operator [] (int &x) {return e[x];}
}go,chr;

int newnode(int ss){
stp[++tot]=ss;
}

int ext(int p,int q,int d){
int nq=newnode(stp[p]+1); ed[nq]=ed[q]-(stp[q]-(stp[p]+1));
fa[nq]=fa[q]; fa[q]=nq;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
for(;p&&ch[p][d]==q;p=fa[p]) ch[p][d]=nq;
return nq;
}

int sam(int p,int d){
int np=ch[p][d];
if(np) return (stp[p]+1==stp[np]) ? np : ext(p,np,d);

np=newnode(stp[p]+1); ed[np]=n;
for(;p&&!ch[p][d];p=fa[p]) ch[p][d]=np;
if(!p) fa[np]=1;
else{
int q=ch[p][d];
fa[np]= (stp[p]+1==stp[q]) ? q : ext(p,q,d);
}
return np;
}

void dfs(int x){
dfn[x]=++tdfn;
pid[tdfn]=x;
sum[tdfn]=stp[x]-stp[fa[x]];
int p,y;
for(p=go(x);p;p=go[p].nxt){
y=go[p].y;
dep[y]=dep[x]+1;
pre[y][0]=x;
dfs(y);
}
}

int LCA(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int t=Mx;t>=0;t--)
if(dep[pre[x][t]]>=dep[y]) x=pre[x][t];
if(x==y) return x;
for(int t=Mx;t>=0;t--)
if(pre[x][t]!=pre[y][t]) x=pre[x][t],y=pre[y][t];
return pre[x][0];
}

int find(LL num){
int l=1,r=tdfn,mid;
while(l<r){
mid=l+r>>1;
if(sum[mid]>=num) r=mid;
else l=mid+1;
}
return l;
}

void getkth(LL num){
int ps=find(num);
int p=pid[ps];
num=sum[ps]-num;
st=ed[p]-stp[p]+1;
len=stp[p]-num;
}

int lcp(int x,int y){
return stp[LCA(id[x],id[y])];
}

bool cmp(int x,int y,int l1,int l2){
int tp=lcp(x,y);
if(tp<l1&&tp<l2) return s[x+tp]>s[y+tp];
return l1>l2;
}

bool check(){
int i,lst=n,blk=0;
for(i=n;i>0;i--){
if(s[i]>s[st]) return 0;
if(cmp(i,st,lst-i+1,len)) blk++,lst=i;
}
return blk+1<=m;
}

int main(){

int i,j,p;

scanf("%d",&m);
scanf("%s",s+1);
n=strlen(s+1);

last=tot=1;
for(i=n;i>0;i--) id[i]=last=sam(last,s[i]-'a');

for(i=2;i<=tot;i++)
chr.push(s[ed[i]-(stp[i]-stp[fa[i]])+1]-'a',i);

for(i=26;i>=0;i--)
for(p=chr(i);p;p=chr[p].nxt)
go.push(fa[chr[p].y],chr[p].y);

dfs(1);
Mx=log2(tot);
for(j=1;j<=Mx;j++)
for(i=1;i<=tot;i++) pre[i][j]=pre[pre[i][j-1]][j-1];
for(i=1;i<=tdfn;i++) sum[i]+=sum[i-1];

LL l=1,r=sum[tdfn],mid;
while(l<r){
mid=l+(r-l)/2;
getkth(mid);
if(check()) r=mid;
else l=mid+1;
}
getkth(l);
for(i=st;i<=st+len-1;i++) printf("%c",s[i]); puts("");
return 0;
}
posted @ 2017-03-20 11:50  _zwl  阅读(132)  评论(0编辑  收藏  举报