# 【题解】Asterisk Substrings Codeforces 1276F 后缀自动机 树链的并

### 星号（star）在中间的串：形如 `u*v`，其中 u、v 均非空

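对 s[1..n-2] 建 SAM，在它的后缀链接树（parent 树）上跑 DSU on Tree，维护当前子树对应的 endpos 集合 E。对每个 i∈E，右半部分 v 必须是后缀 s[i+2..n] 的非空前缀，因此 v 的种数等于这些后缀在 s[3..n] 的后缀树上对应节点到根的树链的并的总长度。集合中的节点按 dfn 序放进 set，插入一个点时用它与前驱、后继的 LCA 增量维护树链的并；每个状态 u 对答案贡献 (len[u]-len[fa(u)]) × 当前并的长度。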

#include <bits/stdc++.h>

using namespace std;
using ll = long long;       // 64-bit type for answer accumulators
constexpr int N = 200010;   // array capacity; a SAM over m chars needs about 2m states
                            // (presumably sized for |s| <= 1e5 -- confirm against the problem limits)
int _w;                     // sink for scanf's return value (see main)

// Suffix automaton with fixed-capacity arrays. State 0 is the root.
struct SAM {
    int ch[N][26];  // ch[v][c]: transition on letter c; 0 means "absent" (the root is never a target)
    int len[N];     // length of the longest string in state v
    int pa[N];      // suffix link of v; pa[0] == -1
    int idx;        // next unused state id, i.e. number of states in use

    // Reset to the automaton of the empty string (only the root).
    void init() {
        memset(ch, 0, sizeof ch);
        memset(len, 0, sizeof len);
        memset(pa, 0, sizeof pa);
        idx = 1;
        pa[0] = -1;
    }

    // Extend the automaton by letter c, where p is the state of the whole
    // string read so far. Returns the state of the new whole string.
    int append( int p, int c ) {
        const int cur = idx++;
        len[cur] = len[p] + 1;

        // Walk suffix links, adding the missing transition to cur.
        for( ; p != -1 && !ch[p][c]; p = pa[p] )
            ch[p][c] = cur;

        if( p == -1 ) {          // fell off the root: link to the root
            pa[cur] = 0;
            return cur;
        }

        const int q = ch[p][c];
        if( len[q] == len[p] + 1 ) {   // q is "solid": link directly
            pa[cur] = q;
            return cur;
        }

        // Split q: the clone takes the strings of length <= len[p] + 1.
        const int clone = idx++;
        memcpy(ch[clone], ch[q], sizeof ch[clone]);
        len[clone] = len[p] + 1;
        pa[clone] = pa[q];
        pa[q] = pa[cur] = clone;
        for( ; p != -1 && ch[p][c] == q; p = pa[p] )
            ch[p][c] = clone;
        return cur;
    }
};

int n;          // length of the input string
char str[N];    // input string, stored 1-indexed in str[1..n] (see main)
SAM sam;        // automaton rebuilt from scratch by every solve_* routine
SAM rsam;       // automaton over s[n], s[n-1], ..., s[3]; built by solve()

// Number of distinct non-empty substrings of str[from..to] (1-indexed, inclusive).
// Rebuilds the global `sam` over that range and sums len[v] - len[link(v)] over
// every non-root state; each distinct substring belongs to exactly one state.
// An empty range (from > to) yields 0.
static ll count_distinct_substrings( int from, int to ) {
    sam.init();
    int last = 0;
    for( int i = from; i <= to; ++i )
        last = sam.append(last, str[i] - 'a');
    ll total = 0;
    for( int v = 1; v < sam.idx; ++v )
        total += sam.len[v] - sam.len[sam.pa[v]];
    return total;
}

// Distinct substrings of the whole string s[1..n].
ll solve_origin() {
    return count_distinct_substrings(1, n);
}

// Distinct substrings of s[2..n]: candidates for the part after a leading '*'.
ll solve_before() {
    return count_distinct_substrings(2, n);
}

// Distinct substrings of s[1..n-1]: candidates for the part before a trailing '*'.
ll solve_after() {
    return count_distinct_substrings(1, n - 1);
}

// Forward-star adjacency list. Edges are prepended, so a node's children are
// visited in reverse insertion order. Iterate with
//   for( int i = head[u]; ~i; i = nxt[i] )   // ~i == 0 exactly when i == -1
// which requires every head[] entry to start at -1.
struct Graph {
    int head[N], nxt[N], to[N], eid;

    // Reset to an empty graph. head[] must be filled with -1: otherwise the
    // global zero-initialisation makes edge 0 look like every node's first
    // edge, and the first link() sets nxt[0] = 0, so traversals never end.
    void init() {
        memset(head, -1, sizeof head);
        eid = 0;
    }

    // Add the directed edge u -> v.
    void link( int u, int v ) {
        to[eid] = v, nxt[eid] = head[u], head[u] = eid++;
    }
};
Graph g, rg;  // g: suffix-link tree of sam; rg: suffix-link tree of rsam (filled in solve())

// Heavy-light decomposition of rsam's suffix-link tree (children lists in rg),
// rooted at state 0. Used to answer LCA queries and to map nodes to/from their
// preorder numbers.
namespace HLD {
// dfn[u]: 1-based preorder number (heavy child visited first); dfnc: preorder
// counter, never reset here, so init() is meant to run once; top[u]: head of
// u's heavy chain; dep[u]: depth, with the root at depth 1.
int dfn[N], dfnc, top[N], dep[N];
// pa[u]: tree parent (-1 for the root); sz[u]: subtree size; son[u]: heavy
// child (-1 for a leaf); val[u] = rsam.len[u], the weight summed by ins_node().
int pa[N], sz[N], son[N], val[N];
// rdfn[k]: the node whose preorder number is k (inverse of dfn).
int rdfn[N];

// First pass: subtree sizes, depths, parents, weights and heavy children.
// Recursive; recursion depth equals the tree depth.
void dfs1( int u, int fa, int d ) {
sz[u] = 1, dep[u] = d, pa[u] = fa;
val[u] = rsam.len[u];
for( int i = rg.head[u]; ~i; i = rg.nxt[i] ) {
int v = rg.to[i];
dfs1(v, u, d+1);
sz[u] += sz[v];
// Strict '>' keeps the first child seen among equally large ones.
if( son[u] == -1 || sz[v] > sz[son[u]] )
son[u] = v;
}
}
// Second pass: preorder numbering with the heavy child first, so every heavy
// chain occupies a contiguous dfn range; tp is the chain head for u.
void dfs2( int u, int tp ) {
dfn[u] = ++dfnc, top[u] = tp;
rdfn[dfnc] = u;
if( son[u] != -1 )
dfs2( son[u], tp );
for( int i = rg.head[u]; ~i; i = rg.nxt[i] ) {
int v = rg.to[i];
if( v != son[u] )
dfs2(v, v);  // every light child starts a new chain
}
}
// Build the decomposition. son[] must be -1-filled before dfs1 marks heavy children.
void init() {
memset(son, -1, sizeof son);
dfs1(0, -1, 1);
dfs2(0, 0);
}
// Lowest common ancestor of u and v: repeatedly lift whichever node has the
// deeper chain head to that head's parent until both lie on the same chain,
// then return the shallower of the two.
int lca( int u, int v ) {
while( top[u] != top[v] ) {
if( dep[top[u]] < dep[top[v]] )
swap(u, v);
u = pa[top[u]];
}
return dep[u] < dep[v] ? u : v;
}
}

int mark[N];       // mark[v] = i when sam state v is reached by reading s[1..i] from the root; 0 otherwise
int rmark[N];      // rmark[x] = i when rsam state x is reached by reading s[n], ..., s[i]
                   // (written in solve(), not read anywhere else in the code shown)
int rmark2nod[N];  // rmark2nod[i] = rsam state reached by reading s[n], ..., s[i], for i >= 3
ll solve_ans = 0;  // running total accumulated by sack() and returned by solve()
ll now = 0;        // weight of the union of root paths over the nodes currently in st
std::set<int> st;  // HLD preorder numbers of the rsam nodes currently inserted

// Activate the end position of SAM state u, if u is a prefix state.
//
// If mark[u] == i (u is the state of prefix s[1..i]), the right part of
// "u*v" must start at i+2. In that case the rsam node x reached by reading
// s[n..i+2] is added to `st`, keyed by its HLD preorder number. States with
// mark[u] == 0 are ignored.
//
// `now` is kept equal to the weight of the union of root->node paths of all
// nodes in `st`. For dfn-sorted nodes x_1..x_k this weight is
//   sum val[x_j] - sum val[lca(x_j, x_{j+1})].
// Inserting x between its neighbours L and R therefore adds val[x], drops the
// L/R term and adds the L/x and x/R terms.
//
// Assumes x is not already in `st`: each prefix state is inserted at most once
// per active set.
void ins_node( int u ) {
    const int pos = mark[u];
    if( !pos ) return;
    const int x = rmark2nod[pos + 2];
    const int key = HLD::dfn[x];

    now += HLD::val[x];
    if( !st.empty() ) {
        // Neighbours of key in dfn order. Only step back from `after` when a
        // predecessor exists: decrementing begin() is undefined behaviour.
        const auto after = st.lower_bound(key);
        const bool has_right = after != st.end();
        const bool has_left = after != st.begin();
        const int L = has_left ? HLD::rdfn[*std::prev(after)] : -1;
        const int R = has_right ? HLD::rdfn[*after] : -1;
        if( has_left && has_right )
            now += HLD::val[HLD::lca(L, R)];   // L and R are no longer adjacent
        if( has_left )
            now -= HLD::val[HLD::lca(L, x)];
        if( has_right )
            now -= HLD::val[HLD::lca(R, x)];
    }
    st.insert(key);
}

// Activate every SAM state in the subtree of u (suffix-link tree g):
// u itself first, then each child subtree in adjacency order.
void ins_tree( int u ) {
    ins_node(u);
    int e = g.head[u];
    while( e != -1 ) {
        const int child = g.to[e];
        ins_tree(child);
        e = g.nxt[e];
    }
}

// Subtree sizes and heavy children of the suffix-link tree g (of sam),
// consumed by sack(). son[u] == -1 marks a leaf.
int sz[N], son[N];

// Fill sz[] and son[] for the subtree rooted at u. Among equally large
// children, the first one in g's adjacency order becomes the heavy child.
void init_sack( int u ) {
    son[u] = -1;
    int total = 1;
    for( int e = g.head[u]; e != -1; e = g.nxt[e] ) {
        const int c = g.to[e];
        init_sack(c);
        total += sz[c];
        if( son[u] == -1 || sz[c] > sz[son[u]] )
            son[u] = c;
    }
    sz[u] = total;
}

// DSU on tree over the suffix-link tree g of sam.
// When clr is false, on return `st`/`now` hold the end positions of every
// prefix state in u's subtree. When clr is true, they are reset to empty.
// For each non-root state u, adds now * (len[u] - len[link(u)]) to solve_ans:
// the number of right parts v times the number of left parts held by state u.
void sack( int u, bool clr ) {
    const int heavy = son[u];

    // 1) Light children first; each one clears the shared set when done.
    for( int e = g.head[u]; e != -1; e = g.nxt[e] ) {
        const int c = g.to[e];
        if( c != heavy ) sack(c, true);
    }
    // 2) Heavy child last; its set is kept and reused for u.
    if( heavy != -1 ) sack(heavy, false);

    // 3) Re-insert the light subtrees, then u itself.
    for( int e = g.head[u]; e != -1; e = g.nxt[e] ) {
        const int c = g.to[e];
        if( c != heavy ) ins_tree(c);
    }
    ins_node(u);

    if( u != 0 ) {
        const int strings_in_state = sam.len[u] - sam.len[sam.pa[u]];
        solve_ans += now * strings_in_state;
    }
    if( clr ) {
        st.clear();
        now = 0;
    }
}

// Count the distinct strings "u*v" with u and v both non-empty.
// The left part u ends at some i <= n-2, and v must be a non-empty prefix of
// s[i+2..n]. Left parts are grouped by the SAM of s[1..n-2]. Right parts are
// handled on the suffix-link tree of the SAM of s[n], s[n-1], ..., s[3].
// main() only calls this for n >= 3. It accumulates into the global
// solve_ans, so it is meant to run once.
ll solve() {
    const int left_end = n - 2;   // last position the left part may end at
    const int right_begin = 3;    // first position the right part may start at

    // Left side: SAM of s[1..left_end] and its suffix-link tree g.
    sam.init();
    int cur = 0;
    for( int i = 1; i <= left_end; ++i )
        cur = sam.append(cur, str[i] - 'a');
    g.init();
    for( int v = 1; v < sam.idx; ++v )
        g.link(sam.pa[v], v);
    // Tag the state of each prefix s[1..i] with its end position i.
    cur = 0;
    for( int i = 1; i <= left_end; ++i ) {
        cur = sam.ch[cur][str[i] - 'a'];
        mark[cur] = i;
    }

    // Right side: SAM of s[n], s[n-1], ..., s[right_begin] and its suffix-link tree rg.
    rsam.init();
    cur = 0;
    for( int i = n; i >= right_begin; --i )
        cur = rsam.append(cur, str[i] - 'a');
    rg.init();
    for( int v = 1; v < rsam.idx; ++v )
        rg.link(rsam.pa[v], v);
    // Record the state reached by reading s[n..i] for each i.
    cur = 0;
    for( int i = n; i >= right_begin; --i ) {
        cur = rsam.ch[cur][str[i] - 'a'];
        rmark[cur] = i;
        rmark2nod[i] = cur;
    }

    HLD::init();
    init_sack(0);
    sack(0, false);
    return solve_ans;
}

int main() {
_w = scanf( "%s", str+1 );
n = (int)strlen(str+1);
ll ans = 0;
ans += solve_origin();
// printf( "after origin = %lld\n", ans );
if( n >= 2 ) {
ans += solve_before();
ans += solve_after();
}
// printf( "before after = %lld\n", ans );
if( n >= 3 ) {
ans += solve();
}
printf( "%lld\n", ans+2 );
return 0;
}

posted @ 2020-02-16 01:19  mlystdcall  阅读(330)  评论(0)  编辑  收藏  举报