树上后缀数组

 

树上后缀数组模板题:

https://www.codechef.com/problems/DIFTRIP

//#pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")

//#include <immintrin.h>
//#include <emmintrin.h>
#include <bits/stdc++.h>
using namespace std;
// Inclusive ascending loop: i = h, h+1, ..., t. Bounds are parenthesized so that
// expressions such as `a&b` are not split by operator precedence.
#define rep(i,h,t) for (int i=(h);i<=(t);i++)
// Inclusive descending loop: i = t, t-1, ..., h.
// NOTE: shares its name with the global array `dep`; `dep[...]` is unaffected because
// a function-like macro only expands when its name is followed by '('.
#define dep(i,t,h) for (int i=(t);i>=(h);i--)
#define ll long long
// Zero-fills x; x must be a real array (not a decayed pointer) so sizeof covers all of it.
#define me(x) memset(x,0,sizeof(x))
#define IL inline
// `register` is ill-formed since C++17, so rint is plain int.
#define rint int
// Reads the next integer from stdin via getchar().
// Non-digit characters before the number are skipped; any '-' among them makes the
// result negative. Returns 0 if the input ends before a digit appears, instead of
// spinning forever.
inline long long rd(){
    long long x=0;
    bool neg=false;
    int c=getchar();  // int, not char: must be able to hold EOF distinct from byte 0xFF
    while(c!=EOF && !isdigit(c)){ if(c=='-') neg=true; c=getchar(); }
    while(c!=EOF && isdigit(c)){ x=x*10+(c-'0'); c=getchar(); }
    return neg?-x:x;
}
// Input buffer for gc(): [A, B) is the not-yet-consumed part of ss.
char ss[1<<24],*A=ss,*B=ss;
// Buffered single-byte reader over stdin (refills ss with fread when empty).
// Returns the next byte as a value in 0..255, or EOF once stdin is exhausted.
// Returning int (not char) keeps a 0xFF byte distinguishable from EOF.
inline int gc()
{
    if (A==B)
    {
        A=ss;
        B=ss+fread(ss,1,sizeof(ss),stdin);
        if (A==B) return EOF;
    }
    return static_cast<unsigned char>(*A++);
}
// Replaces x with y only when y is strictly greater (compares with operator>).
template<class T>void maxa(T &x,T y)
{
    if (!(y>x)) return;
    x=y;
}
// Replaces x with y only when y is strictly smaller (compares with operator<).
template<class T>void mina(T &x,T y)
{
    if (!(y<x)) return;
    x=y;
}
// Reads the next integer from the gc() byte stream into x.
// Characters before the number that are not digits are skipped; any '-' among them
// makes x negative. If gc() reports EOF before a digit appears, x is set to 0 and the
// function returns instead of spinning forever.
template<class T>void read(T &x)
{
    int f=1,c;
    while (c=gc(),c<48||c>57)
    {
        if (c==EOF) { x=0; return; }
        if (c=='-') f=-1;
    }
    x=(c^48);
    while(c=gc(),c>47&&c<58) x=x*10+(c^48);
    x*=f;
}
const int mo=1e9+7;
// Computes x^y mod mo by iterative binary exponentiation.
// x may be any int (negative or >= mo); it is reduced into [0, mo) first.
// y <= 0 returns 1.
long long fsp(int x,int y)
{
    long long b=(static_cast<long long>(x)%mo+mo)%mo;
    long long r=1;
    while (y>0)
    {
        if (y&1) r=r*b%mo;
        b=b*b%mo;  // b < mo < 2^30, so b*b fits in long long
        y>>=1;
    }
    return r;
}
// 2-D point / vector with 64-bit integer coordinates.
struct cp {
    long long x,y;
    cp operator +(const cp &o) const
    {
        return cp{x+o.x,y+o.y};
    }
    cp operator -(const cp &o) const
    {
        return cp{x-o.x,y-o.y};
    }
    // Cross product (z-component of the 3-D cross product): positive when o lies
    // counter-clockwise of *this, negative when clockwise, 0 when collinear.
    long long operator *(const cp &o) const
    {
        return x*o.y-y*o.x;
    }
    // Half-plane index for angular sorting: 0 for angles in [0, pi) (the origin also
    // maps to 0), 1 for angles in [pi, 2*pi).
    int half() const { return y < 0 || (y == 0 && x < 0); }
};
// Plain aggregate of three ints; not referenced anywhere else in the visible code.
struct re{
    int a;
    int b;
    int c;
};
typedef unsigned long long ull;
// Array capacity for tree nodes (nodes are 1..n; index 0 is the "past the root" sentinel).
const int N=2e5;
// Polynomial hash base; hash arithmetic wraps modulo 2^64 through unsigned overflow.
const ull base=23333;
// ve[u]: adjacency list of node u.
std::vector<int> ve[N];
// ba[i] = base^i (filled for i <= n in main). dfs() reads ba[1<<(i-1)] for i up to 19,
// i.e. up to ba[1<<18], which is beyond N, so ba is sized 1<<19 to keep that read in bounds.
ull ba[1<<19];
// gg[i][u]: hash of the 2^i labels a[u], a[parent(u)], ... (positions past the root count as 0).
ull gg[20][N];
// bz[i][u]: 2^i-th ancestor of u (0 once past the root); dep[u]: depth (root = 1); a[u]: node degree.
int bz[20][N],dep[N],a[N];
// Suffix-array results and scratch arrays used by asa(); n is the node count.
int x[N],y[N],sa[N],h[N],rk[N],c[N],xx[N],n;
void dfs(int x,int y)
{
    dep[x]=dep[y]+1; 
    bz[0][x]=y; gg[0][x]=a[x];
    rep(i,1,19) bz[i][x]=bz[i-1][bz[i-1][x]];
    rep(i,1,19) gg[i][x]=gg[i-1][x]+gg[i-1][bz[i-1][x]]*ba[(1<<(i-1))];
    for (auto v:ve[x])
      if (v!=y)
      {
          dfs(v,x);
      }
}
int lcp(int x,int y)
{
    int xx=x,yy=y;
    dep(i,19,0)
      if (gg[i][xx]==gg[i][yy])
        xx=bz[i][xx],yy=bz[i][yy];
    return min(dep[x]-dep[xx],dep[y]-dep[yy]); //会有跳到根的情况 
}
// Builds the suffix array of the tree. Node u stands for the label string
// a[u], a[bz[0][u]], a[bz[0][bz[0][u]]], ... read upward to the root.
// Prefix doubling: round k re-ranks nodes by the pair (x[u], x[bz[k][u]]),
// extending the compared length from 2^k to 2^(k+1) labels.
// Results (1-indexed):
//   sa[r] - node at sorted position r
//   rk[u] - sorted position of node u
//   h[r]  - LCP of the strings of sa[r-1] and sa[r]; for r == 1 it compares against
//           sa[0], which this function never writes.
// Globals x, y, xx, c are scratch space. Index 0 is the "past the root" sentinel:
// x[0] and y[0] are never written here, so (assuming they start at 0, as globals do)
// it ranks below every rank produced by re-ranking, which starts at 1.
void asa(int n)
{
    int p=0;
    // Initial order: counting sort by the first label (the raw label is the initial rank).
    rep(i,1,n) c[i]=0;
    rep(i,1,n) c[x[i]=a[i]]++;
    rep(i,1,n) c[i]+=c[i-1];
    dep(i,n,1) sa[c[x[i]]--]=i;
    // i == 2^k; rounds run while 2^k <= n, so the final ranks cover more than n labels,
    // which is at least every node's depth.
    for (int i=1,k=0;i<=n;i<<=1,k++)
    {
        // Second key: rank of the 2^k-th ancestor (0 once past the root).
        // c[0] is not reset; it counts xx == 0 and the placement loop decrements it back to 0.
        rep(j,1,n) xx[j]=x[bz[k][j]];
        rep(j,1,n) c[j]=0;
        rep(j,1,n) c[xx[j]]++;
        rep(j,1,n) c[j]+=c[j-1];
        dep(j,n,1) y[c[xx[j]]--]=j;
        
        // This part also has to be handled differently: with the usual (string) approach
        // some y values would coincide (presumably because several nodes share the same
        // 2^k-th ancestor), so a stable counting sort by the first key runs over the y order.
        rep(j,1,n) c[j]=0;
        rep(j,1,n) c[x[y[j]]]++;
        rep(j,1,n) c[j]+=c[j-1];
        dep(j,n,1) sa[c[x[y[j]]]--]=y[j];
        // Re-rank: y now holds the previous ranks; nodes with equal (first, second) key
        // pairs receive the same new rank.
        swap(x,y); x[sa[1]]=1; p=2;
        rep(j,2,n)
          x[sa[j]]=y[sa[j]]==y[sa[j-1]]&&y[bz[k][sa[j]]]==y[bz[k][sa[j-1]]]?p-1:p++;
    }
    rep(i,1,n) rk[sa[i]]=i;
    rep(i,1,n)
    {
        h[i]=lcp(sa[i-1],sa[i]); // The string-SA property h[rk[i]] >= h[rk[i-1]] - 1 does not hold on a tree, so each LCP is computed with binary-lifting hashes.
    }
}
// Reads an n-node tree (n, then n-1 edges "u v"), labels each node with its degree,
// and prints the number of distinct label sequences read along upward paths
// (starting at any node, moving toward root 1, any length >= 1), assuming no hash
// collisions in lcp(): sum over sorted nodes of (string length - LCP with predecessor).
int main()
{
#ifdef LOCAL
    // Local debugging only: on a judge these files do not exist, and redirecting
    // unconditionally would make the program read nothing.
    freopen("1.in","r",stdin);
    freopen("1.out","w",stdout);
#endif
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin>>n;
    rep(i,1,n-1)
    {
        int u,v;  // named u/v so they do not shadow the global arrays x/y
        cin>>u>>v;
        ve[u].push_back(v); ve[v].push_back(u);
    }
    rep(i,1,n) a[i]=ve[i].size();
    ba[0]=1;
    rep(i,1,n) ba[i]=ba[i-1]*base;
    dfs(1,0);
    asa(n);
    ll ans=0;
    rep(i,1,n)
        ans+=dep[sa[i]]-h[i];
    cout<<ans<<'\n';
    return 0;
}
View Code

 

posted @ 2021-06-07 18:06  尹吴潇  阅读(70)  评论(0编辑  收藏  举报