Gym 102994D. String Theory
题目大意
给定一个长为 \(n\ (1\le n\le3\cdot10^5,\ \sum n\le10^6)\) 的串 \(s\) 和一个整数 \(k\ (1\le k\le20)\)。求 \(s\) 中有多少子串可以被分割成连续 \(k\) 个相同的子段。
思路
我们枚举子串被分割成 \(k\) 段后每一段的长度 \(len\),并把串 \(s\) 从头开始每 \(len\) 个字符分成一块。一个合法的子串要么恰好由 \(k\) 个完整的块组成,要么由 \(k-1\) 个完整的块和两边多出来的部分组成。对于第一种情况,我们求原串的 \(sa\),依次求以相邻两个块的首字符开始的后缀的 \(lcp\),如果其长度 \(\ge len\),说明后一个块可以加入当前这一段连续相同的块中;每次检查,如果块数 \(\ge k\),那么答案 \(+1\)。注意每次计数从 \(1\) 开始。对于第二种情况,我们在第一种情况的过程中,如果块数 \(\ge k-1\) 就进行处理:再求一下反串的 \(sa\),记以左侧第一个没有算入的块的右端点结尾的前缀与以第一个被算入的块的右端点结尾的前缀的最长公共后缀的长度为 \(x\),以右侧第一个没有算入的块的左端点开始的后缀与以最后一个被算入的块的左端点开始的后缀的 \(lcp\) 长度为 \(y\);对 \(x,y\) 都与 \(len-1\) 取 \(\min\) 后,\(\max(0,x+y-len+1)\) 就是这一部分的答案。总的复杂度为 \(O(n\log^2 n)\)。其他具体细节见写得很丑的代码。
代码
#include<bits/stdc++.h>
#include<unordered_map>
#include<unordered_set>
using namespace std;
using LL = long long;
using LD = long double;
using ULL = unsigned long long;
using PII = pair<int, int>;
using TP = tuple<int, int, int>;
#define all(x) x.begin(),x.end()
#define mst(x,v) memset(x,v,sizeof(x))
#define mk make_pair
//#define int LL
//#define double LD
//#define lc p*2
//#define rc p*2+1
#define endl '\n'
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
//#pragma warning(disable : 4996)
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
const double eps = 1e-10;
const LL MOD = 1000000007;
const LL mod = 998244353;
const int maxn = 300010;
// T: number of test cases (read in main).
// N: length of the current string S (assigned at the top of solve).
// M: repetition count read per test case (the k of the problem statement).
// K: current offset used by compare_sa / compare_sa2 during prefix doubling.
int T, N, M, K;
// Current input string. solve() reverses it in place before the second
// suffix array is built, so it is left reversed when solve() returns.
string S;
// Suffix-array data built by construct_sa / construct_lcp:
// rk = rank per start position, tmp = scratch ranks for one doubling round,
// sa = start positions in sorted order, lcp = common-prefix length of the
// suffixes at adjacent positions of sa.
int rk[maxn], tmp[maxn], sa[maxn], lcp[maxn];
// Sparse table of range minima over lcp[]; filled by LCP_init, read by LCP.
int ST[maxn][20];
// Same structures, filled by the *2 functions. solve() runs those after
// reversing S, so they describe the reversed string.
int rk2[maxn], tmp2[maxn], sa2[maxn], lcp2[maxn];
// Sparse table of range minima over lcp2[]; filled by LCP_init2, read by LCP2.
int ST2[maxn][20];
// Strict-weak ordering used while sorting sa[] in one prefix-doubling round.
// Two start positions are compared by the pair (rk[p], rk[p + K]); when
// p + K runs past N the second component is treated as -1, so the shorter
// suffix sorts first.
bool compare_sa(int i, int j)
{
    if (rk[i] == rk[j])
    {
        // First halves tie: decide on the rank of the half that starts K later.
        const int tail_i = (i + K > N) ? -1 : rk[i + K];
        const int tail_j = (j + K > N) ? -1 : rk[j + K];
        return tail_i < tail_j;
    }
    return rk[i] < rk[j];
}
// Builds sa[] and rk[] for S (including the empty suffix at position N)
// by prefix doubling: each round sorts by the first 2*K characters using
// the ranks of the first K characters from the previous round.
void construct_sa()
{
    // Round 0: rank is the character code; the empty suffix gets -1 so it
    // sorts before everything else.
    for (int pos = 0; pos <= N; ++pos)
    {
        sa[pos] = pos;
        rk[pos] = (pos < N) ? S[pos] : -1;
    }

    K = 1;
    while (K <= N)
    {
        sort(sa, sa + N + 1, compare_sa);

        // Re-rank: a new class starts whenever the previous entry compares
        // strictly smaller than the current one.
        int classes = 0;
        tmp[sa[0]] = classes;
        for (int p = 1; p <= N; ++p)
        {
            if (compare_sa(sa[p - 1], sa[p]))
                ++classes;
            tmp[sa[p]] = classes;
        }

        for (int pos = 0; pos <= N; ++pos)
            rk[pos] = tmp[pos];

        K *= 2;
    }
}
void construct_lcp()
{
for (int i = 0; i <= N; i++)
rk[sa[i]] = i;
int h = 0;
lcp[0] = 0;
for (int i = 0; i < N; i++)
{
int j = sa[rk[i] - 1];
if (h > 0)
h--;
for (; j + h < N && i + h < N; h++)
{
if (S[j + h] != S[i + h])
break;
}
lcp[rk[i] - 1] = h;
}
}
// Builds the sparse table ST over lcp[0 .. n-1]:
// ST[i][j] = min(lcp[i], ..., lcp[i + 2^j - 1]).
void LCP_init(int n)
{
    for (int i = 0; i < n; ++i)
        ST[i][0] = lcp[i];

    // `span` is the window width 2^level handled on this level.
    for (int level = 1, span = 2; span <= n; ++level, span <<= 1)
    {
        const int half = span >> 1;
        for (int i = 0; i + span <= n; ++i)
            ST[i][level] = min(ST[i][level - 1], ST[i + half][level - 1]);
    }
}
// Minimum of lcp[] over the half-open index interval [l, r), answered from
// the sparse table ST with two overlapping power-of-two windows.
// With l and r being the sorted positions of two suffixes (l < r), this is
// the length of their longest common prefix. Returns 0 for an empty interval.
int LCP(int l, int r)//[l,r)
{
    if (l >= r)
        return 0;
    // floor(log2(width)) via an integer bit scan: exact for every positive
    // width and avoids a double round trip (log2 + floor + narrowing to int).
    // width >= 1 here, so __builtin_clz is never called with 0.
    const unsigned width = static_cast<unsigned>(r - l);
    const int k = std::numeric_limits<unsigned>::digits - 1 - __builtin_clz(width);
    return min(ST[l][k], ST[r - (1 << k)][k]);
}
// Ordering predicate for the second suffix array (rk2/sa2): compares start
// positions by the pair (rk2[p], rk2[p + K]), using -1 for the second
// component when p + K lies beyond N.
bool compare_sa2(int i, int j)
{
    if (rk2[i] == rk2[j])
    {
        // Equal leading halves: fall back to the rank K positions further on.
        const int tail_i = (i + K > N) ? -1 : rk2[i + K];
        const int tail_j = (j + K > N) ? -1 : rk2[j + K];
        return tail_i < tail_j;
    }
    return rk2[i] < rk2[j];
}
// Builds sa2[] and rk2[] for the current contents of S (solve() calls this
// after reversing S) with the same prefix-doubling scheme as construct_sa.
void construct_sa2()
{
    // Initial ranks are character codes; position N (empty suffix) gets -1.
    for (int pos = 0; pos <= N; ++pos)
    {
        sa2[pos] = pos;
        rk2[pos] = (pos < N) ? S[pos] : -1;
    }

    K = 1;
    while (K <= N)
    {
        sort(sa2, sa2 + N + 1, compare_sa2);

        // Assign fresh ranks; the counter advances only on a strict increase.
        int classes = 0;
        tmp2[sa2[0]] = classes;
        for (int p = 1; p <= N; ++p)
        {
            if (compare_sa2(sa2[p - 1], sa2[p]))
                ++classes;
            tmp2[sa2[p]] = classes;
        }

        for (int pos = 0; pos <= N; ++pos)
            rk2[pos] = tmp2[pos];

        K *= 2;
    }
}
// Fills lcp2[] from sa2[]: lcp2[r] is the common-prefix length of the
// suffixes at sorted positions r and r + 1 of sa2. Also sets rk2[sa2[r]] = r.
void construct_lcp2()
{
    for (int r = 0; r <= N; ++r)
        rk2[sa2[r]] = r;

    lcp2[0] = 0;
    int matched = 0;  // reused across iterations, shrinking by at most one each time
    for (int start = 0; start < N; ++start)
    {
        const int slot = rk2[start] - 1;  // sorted position of the predecessor
        const int prev = sa2[slot];       // suffix ranked immediately before `start`
        if (matched > 0)
            --matched;
        while (prev + matched < N && start + matched < N
               && S[prev + matched] == S[start + matched])
            ++matched;
        lcp2[slot] = matched;
    }
}
// Builds the sparse table ST2 over lcp2[0 .. n-1]:
// ST2[i][j] = min(lcp2[i], ..., lcp2[i + 2^j - 1]).
void LCP_init2(int n)
{
    for (int i = 0; i < n; ++i)
        ST2[i][0] = lcp2[i];

    // Each level doubles the window width (`span` == 2^level).
    for (int level = 1, span = 2; span <= n; ++level, span <<= 1)
    {
        const int half = span >> 1;
        for (int i = 0; i + span <= n; ++i)
            ST2[i][level] = min(ST2[i][level - 1], ST2[i + half][level - 1]);
    }
}
// Minimum of lcp2[] over the half-open index interval [l, r), answered from
// the sparse table ST2 with two overlapping power-of-two windows.
// With l and r being sorted positions in sa2 (l < r), this is the longest
// common prefix of those two suffixes. Returns 0 for an empty interval.
int LCP2(int l, int r)//[l,r)
{
    if (l >= r)
        return 0;
    // Integer floor(log2(width)) instead of floor(log2(double)): exact, and
    // no implicit double -> int narrowing. width >= 1, so the bit scan is
    // never given 0.
    const unsigned width = static_cast<unsigned>(r - l);
    const int k = std::numeric_limits<unsigned>::digits - 1 - __builtin_clz(width);
    return min(ST2[l][k], ST2[r - (1 << k)][k]);
}
// Handles one test case: counts the substrings of S that are M consecutive
// copies of the same piece and prints the count followed by a newline.
//
// Reads the globals M and S, sets N, rebuilds both suffix-array structures,
// and leaves S reversed on return (main re-reads S for every test case).
//
// Method: for every piece length `len`, S is cut into blocks of `len`
// characters starting at position 0. A valid substring either
//   (1) consists of exactly M whole blocks, or
//   (2) consists of M - 1 whole blocks plus a proper suffix of the block on
//       their left and a proper prefix of the block on their right.
void solve()
{
    N = S.length();
    if (M == 1)
    {
        // Every non-empty substring is one copy of itself.
        cout << (LL)N * (LL)(N + 1) / 2LL << endl;
        return;
    }

    // Forward structures answer "longest common prefix of two suffixes";
    // the structures built on the reversed string answer "longest common
    // suffix of two prefixes".
    construct_sa(), construct_lcp(), LCP_init(N + 1);
    reverse(all(S));
    construct_sa2(), construct_lcp2(), LCP_init2(N + 1);

    LL ans = 0;
    for (int len = 1; len * M <= N; len++)
    {
        // Number of consecutive identical blocks ending at block i.
        int cnt = 1;
        // The loop condition guarantees block i + 1 starts inside S, so no
        // separate out-of-range case is needed in the body.
        for (int i = 0; (i + 1) * len < N; i++)
        {
            const int cur = i * len;
            const int nxt = cur + len;
            // Common prefix of the suffixes starting at blocks i and i + 1.
            const int fwd = LCP(min(rk[cur], rk[nxt]), max(rk[cur], rk[nxt]));

            // Case (2): blocks l+1 .. i are M - 1 identical blocks; block l
            // is the neighbour on the left, block i + 1 the one on the right.
            const int l = i + 1 - M;
            if (cnt >= M - 1 && l >= 0)
            {
                // Reversed-string start positions that correspond to the
                // prefixes of S ending at the right ends of blocks l+1 and l.
                const int p = N - (l + 2) * len;
                const int q = N - (l + 1) * len;
                const int back = LCP2(min(rk2[p], rk2[q]), max(rk2[p], rk2[q]));
                // Cap both extensions at len - 1 so that whole-block
                // alignments are left to case (1).
                const int x = min(len - 1, back);
                const int y = min(len - 1, fwd);
                // Left part a and right part b need a + b == len with
                // 1 <= a <= x and 1 <= b <= y.
                ans += max(0, x + y - len + 1);
            }

            // Case (1): extend or restart the run of identical whole blocks.
            if (fwd >= len)
                cnt++;
            else
                cnt = 1;
            if (cnt >= M)
                ans++;
        }
    }
    cout << ans << endl;
}
// Entry point: reads the number of test cases, then for each one reads the
// repetition count M and the string S and lets solve() print the answer.
int main()
{
    IOS;
    cin >> T;
    for (;;)
    {
        // Same test-then-decrement order as a `while (T--)` loop.
        if (T-- == 0)
            break;
        cin >> M >> S;
        solve();
    }
    return 0;
}