[雅礼集训2017 Day1]字符串(SAM+根号分治)
题目:LOJ#6031
题目描述:
记\(f(s,w,l,r)\)表示字符串\(w\)的第\(l\)位到第\(r\)位,即\(w[l...r]\),在字符串\(s\)中出现的次数
给定一个长度为\(n\)的字符串\(s\),给定\(m\)个区间\([l_{i},r_{i}](i∈[0,m))\),\(q\)次询问,每次询问给出一个长度为\(k\)(\(k\)为定值)的字符串\(w\),以及\(a,b\),求\(\sum_{i=a}^{b}f(s,w,l_{i},r_{i})\)
\(0 < n,m,k,q \leq 10^{5}\),\(\sum|w| \leq 10^{5}\),\(0 \leq l,r < k\),\(0 \leq a,b < m\)
蒟蒻题解:
观察数据发现\(\sum|w| \leq 10^{5}\),即\(k \cdot q \leq 10^{5}\),考虑根号分治
字符串\(s\)是固定的,且后面的询问都是问某个子串在\(s\)中出现的次数,不妨对\(s\)建后缀自动机
以下算复杂度时假设\(k \cdot q\),\(m\)和\(n\)是同一级别的
-
当\(q \geq k\)时,询问个数多,但是每个询问的字符串都较短(由\(k \cdot q \leq 10^{5}\)可知\(k \leq \sqrt{10^{5}}\)),对于每个询问的字符串\(w\),它的子串个数是\(k^{2}\)级别的。对于每个询问,暴力枚举\(w\)的每个子串的左端点\(L\),在\(s\)建出的后缀自动机上逐字符向右转移,就能得到\(w\)的每个子串\(w[L...R]\)在\(s\)中的出现次数,复杂度是\(q \cdot k^{2} = n \cdot k \leq n \cdot \sqrt{n}\)。对于\(m\)个区间,需要统计其中恰好满足\(l_{i}=L,r_{i}=R\)且\(i∈[a,b]\)的区间个数:这个可以提前用动态数组按\((L,R)\)存下每种区间出现的所有下标\(i\)(按读入顺序存入,天然有序),询问时二分去找。在一次询问中,每一种\((L,R)\)至多被二分一次,而动态数组非空的\((L,R)\)至多有\(\min(k^{2},m)\)种,所以二分的总复杂度是\(O(q \cdot \min(k^{2},m) \cdot \log m)\),总的时间复杂度就是\(O(n \sqrt{n} + q \cdot \min(k^{2},m) \cdot \log m)\)
-
当\(q < k\)时,询问的字符串较长,但是询问个数较少,每次询问可以把第\(a\)到第\(b\)个区间\([l_{i},r_{i}]\)都记在位置\(r_{i}\)上,这一步的总复杂度是\(q \cdot m\)的。然后让\(w\)在后缀自动机上跑,当前右端点是\(r_{i}\)时,对应后缀自动机上的点\(x\),记录当前匹配的长度\(y\):如果\(r_{i} - l_{i} + 1 > y\),说明\(w[l_{i}...r_{i}]\)没有在\(s\)中出现,贡献为\(0\);否则左端点\(l_{i}\)可以倍增去找,即从\(x\)沿后缀链接向上跳,找到\(len\)至少为\(r_{i} - l_{i} + 1\)的深度最小的节点,该节点的出现次数就是这个区间的答案。总的时间复杂度是\(O(q \cdot k + q \cdot m \cdot \log n)\),但是其实很难跑满,远远跑不满的,如果被卡常的话可以将一部分换成第一种的解法
参考程序:
#include<bits/stdc++.h>
using namespace std;
// Loop-counter shorthand, used as `for (Re i = ...; ...)`.
// It previously expanded to `register int`; the `register` storage class
// specifier was removed in C++17, so it now expands to a plain `int`.
#define Re int
typedef long long ll;
const int N = 200005;  // capacity of the per-state and per-position arrays below
const int S = 320;     // side length of vt (1-based substring endpoints, q >= k branch)

// One interval [l, r] of the query pattern; main stores both ends 1-based.
struct info
{
    int l, r;
} a[N];

int n, m, q, k;        // read by main in this order: |s|, #intervals, #queries, pattern length
int lst = 1, num = 1;  // lst: state reached by the text inserted so far; num: number of states (1 = root)
int len[N], lk[N];     // per state: longest substring length and suffix link
int cnt[N], d[N];      // counting-sort buckets keyed by len, and the states ordered by len
int f[N];              // 1 for every state created for a prefix, then summed along lk in main
int ch[N][28];         // transitions; the column is (character - 97)
int fa[N][18];         // jump table over lk: fa[x][i] is 2^i links above x (0 when absent)
ll ans;                // answer accumulated for the query being processed
char s[N];             // 1-indexed buffer: holds the text first, then each query pattern
// vt[l][r]: indices of the intervals equal to [l, r], in input order (q >= k branch).
// qq[r]:    left ends of the current query's intervals ending at r (q < k branch).
std::vector<int> vt[S][S], qq[N];
// Reads the next run of decimal digits from stdin and returns its value.
// Every non-digit byte before the run is skipped; the first byte after the
// run is consumed. Returns 0 if the input ends before any digit is found.
inline int read()
{
    // Keep getchar()'s result in an int: stored in a char, EOF is
    // indistinguishable from a skipped byte and the skip loop below would
    // spin forever once the input is exhausted.
    int c = getchar();
    while (c != EOF && (c < 48 || c > 57)) c = getchar();
    int ans = 0;
    while (c >= 48 && c <= 57)
    {
        ans = ans * 10 + (c - 48);
        c = getchar();
    }
    return ans;
}
// Prints x in decimal followed by a newline, one putchar per character.
// The digit extraction only yields digit characters for x >= 0.
inline void write(ll x)
{
    char digits[25];
    int used = 0;
    // Collect the digits least-significant first; a zero input still yields one '0'.
    do
    {
        digits[++used] = x % 10 + 48;
        x /= 10;
    } while (x);
    // Emit them in reverse, i.e. most-significant first.
    for (; used; --used) putchar(digits[used]);
    putchar('\n');
}
// Extends the suffix automaton by one character whose code x is the column
// used in ch. Creates one new state (plus a clone when a transition target has
// to be split), marks the new state with f = 1, and leaves lst on the new state.
inline void ins(int x)
{
    const int cur = ++num;
    int p = lst;
    len[cur] = len[p] + 1;
    f[cur] = 1;

    // Walk the suffix links, adding the missing x-transitions to the new state.
    for (; p && !ch[p][x]; p = lk[p]) ch[p][x] = cur;

    if (p == 0)
    {
        // Fell off above the root: the new state links to the root.
        lk[cur] = 1;
    }
    else
    {
        const int nxt = ch[p][x];
        if (len[nxt] == len[p] + 1)
        {
            lk[cur] = nxt;
        }
        else
        {
            // nxt also represents longer strings: split off a clone of length
            // len[p] + 1 that copies nxt's transitions and suffix link.
            const int clone = ++num;
            len[clone] = len[p] + 1;
            lk[clone] = lk[nxt];
            lk[cur] = clone;
            lk[nxt] = clone;
            memcpy(ch[clone], ch[nxt], sizeof ch[clone]);
            // Redirect every x-transition that still points at nxt along the link chain.
            for (; ch[p][x] == nxt; p = lk[p]) ch[p][x] = clone;
        }
    }
    lst = cur;
}
// Climbs from state x through the jump table fa, taking a jump of level
// 16, 15, ..., 0 whenever the state it lands on still has len >= y.
// Returns the state reached after all levels have been tried.
inline int find(int x, int y)
{
    int level = 17;
    while (level-- > 0)
    {
        const int up = fa[x][level];
        if (len[up] >= y) x = up;
    }
    return x;
}
int main()
{
n = read(), m = read(), q = read(), k = read();
scanf("%s", s + 1);
for (Re i = 1; i <= n; ++i) ins(s[i] - 97);
for (Re i = 1; i <= num; ++i) ++cnt[len[i]];
for (Re i = 1; i <= n; ++i) cnt[i] += cnt[i - 1];
for (Re i = num; i; --i) d[cnt[len[i]]--] = i;
for (Re i = num; i; --i) f[lk[d[i]]] += f[d[i]];
if (q >= k)
{
for (Re i = 0; i < m; ++i)
{
int u = read() + 1, v = read() + 1;
vt[u][v].push_back(i);
}
while (q--)
{
scanf("%s", s + 1), ans = 0;
int u = read(), v = read();
for (Re i = 1; i <= k; ++i)
{
int j = i - 1, t = 1;
while (j < k)
{
++j;
if (!ch[t][s[j] - 97]) break;
t = ch[t][s[j] - 97];
int w = vt[i][j].size();
if (!w) continue;
int l = 0, r = w - 1, ls = w, rs = -1;
while (l <= r)
{
int mid = l + r >> 1;
if (vt[i][j][mid] >= u) ls = mid, r = mid - 1;
else l = mid + 1;
}
if (ls == w) continue;
l = 0, r = w - 1;
while (l <= r)
{
int mid = l + r >> 1;
if (vt[i][j][mid] <= v) rs = mid, l = mid + 1;
else r = mid - 1;
}
if (ls <= rs) ans += 1ll * f[t] * (rs - ls + 1);
}
}
write(ans);
}
}
else
{
for (Re i = 2; i <= num; ++i) fa[i][0] = lk[i];
for (Re i = 1; (1 << i) < n; ++i)
for (Re j = 1; j <= num; ++j) fa[j][i] = fa[fa[j][i - 1]][i - 1];
for (Re i = 0; i < m; ++i) a[i].l = read() + 1, a[i].r = read() + 1;
while (q--)
{
scanf("%s", s + 1), ans = 0;
int u = read(), v = read();
for (Re i = u; i <= v; ++i) qq[a[i].r].push_back(a[i].l);
for (Re i = 1, j = 1, l = 0; i <= k; ++i)
{
int w = s[i] - 97;
if (ch[j][w]) j = ch[j][w], ++l;
else
{
while (j > 1 && !ch[j][w]) j = lk[j];
if (ch[j][w]) l = len[j] + 1, j = ch[j][w];
else l = 0;
}
w = qq[i].size();
if (!w) continue;
if (!l)
{
qq[i].clear();
continue;
}
for (Re t = 0; t < w; ++t)
{
int uu = qq[i][t];
if (i - uu + 1 > l) continue;
ans += f[find(j, i - uu + 1)];
}
qq[i].clear();
}
write(ans);
}
}
return 0;
}

浙公网安备 33010602011771号