题解 CF862F Mahmoud and Ehab and the final stage
首先令 \(a_i=\operatorname{lcp}(s_i,s_{i+1})\),原问题转化成求 \([l,r]\) 内 \(\min\limits_{L\le i< R}\{s_i\}\times(R-L+1)\) 的最大值,单点修改。对于 \(L=R\) 的部分就是求最大长度,显然线段树可以解决。剩下的就是求 \(\min\limits_{L\le i\le R}\{s_i\}\times(R-L+2)\) 的最大值,\(l\le L\le R<r\)。
注意到一个很重要的性质:字符串长度总和不超过 \(10^5\),即 \(\sum a_i\le 10^5\)。所以考虑根号分治。设定阈值 \(T\),把答案按照区间最小值的大小分成两类。
-
区间最小值 \(\le T\)。假设已经知道最小值为 \(x\),我们只关心某个数是否大于等于 \(x\)。那么把大于等于 \(x\) 的数标为 \(1\),剩下的标为 \(0\),只需要求出区间 \([l,r-1]\) 内的最长的只包含 \(1\) 的连续段。用 \(T\) 棵线段树维护即可。
-
区间最小值 \(>T\)。这样的位置最多 \(\dfrac{10^5}{T}\) 个。假设其中存在一个连续区间 \([p,q]\) 并且是 \([L,R]\) 的子区间,对于其中的每一个 \(i\),用单调栈求出 \(lp_i,rp_i\),分别表示 \(i\) 往左 / 往右第一个小于 \(a_i\) 的位置。那么以 \(a_i\) 为最小值的区间的答案就是 \(a_i\times(rp_i-lp_i)\)。
总时间复杂度 \(O(q(\dfrac{\sum|s_i|}{T}+T\log n))\),取 \(T=\sqrt{\dfrac{\sum|s_i|}{\log n}}\) 时最优,为 \(O(q\sqrt{(\sum|s_i|)\log n})\)。
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <string>
using namespace std;
typedef long long ll;
typedef pair <int, int> pii;
#define ci const int
inline int max(ci &x, ci &y) { return x > y ? x : y; }
inline int min(ci &x, ci &y) { return x < y ? x : y; }
inline void swap(int &x, int &y) { x ^= y ^= x ^= y; }
inline void chmax(ll &x, ll y) { if(x < y) x = y; }
inline void chmax(int &x, int y) { if(x < y) x = y; }
#define rei register int
#define rep(i, l, r) for(rei i = l, i##end = r; i <= i##end; ++ i)
#define per(i, r, l) for(rei i = r, i##end = l; i >= i##end; -- i)
const int N = 1e5 + 15, T = 30;
string s[N];
int lch[N * 2], rch[N * 2], sgtt;
int pos[N], L[N], R[N], tot;
int stk[N];
int a[N], n, q;
inline int lcp(string &a, string &b) {
int len = 0, n = min(a.size(), b.size());
for(; len < n && a[len] == b[len]; ++ len) ;
return len;
}
struct node {
int mx, cl, cr;
bool a1;
node(int _mx = 0, int _cl = 0, int _cr = 0, bool _a1 = 0)
: mx(_mx), cl(_cl), cr(_cr), a1(_a1) { }
} ;
node merge(const node &x, const node &y) {
return node(max(max(x.mx, y.mx), x.cr + y.cl),
x.cl + (x.a1 ? y.cl : 0), y.cr + (y.a1 ? x.cr : 0), x.a1 && y.a1);
}
int build(int l, int r) {
int p = ++ sgtt;
if(l == r) return p;
int mid = l + r >> 1;
lch[p] = build(l, mid);
rch[p] = build(mid + 1, r);
return p;
}
#define lc lch[p]
#define rc rch[p]
struct SSGT {
int t[N * 2];
inline void update(int p) {
t[p] = max(t[lc], t[rc]);
}
void build(int p, int l, int r) {
if(l == r) {
t[p] = s[l].size();
return ;
}
int mid = l + r >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
update(p);
}
void modify(int p, int l, int r, ci &tt) {
if(l == r) {
t[p] = s[l].size();
return ;
}
int mid = l + r >> 1;
if(tt <= mid) modify(lc, l, mid, tt);
else modify(rc, mid + 1, r, tt);
update(p);
}
int query(int p, int l, int r, ci &tl, ci &tr) {
if(tl <= l && r <= tr) return t[p];
int mid = l + r >> 1, res = 0;
if(tl <= mid) chmax(res, query(lc, l, mid, tl, tr));
if(mid < tr) chmax(res, query(rc, mid + 1, r, tl, tr));
return res;
}
} t2;
struct SGT {
int val;
node t[N * 2];
inline void update(int p) {
t[p] = merge(t[lc], t[rc]);
}
void build(int p, int l, int r) {
if(l == r) {
t[p].mx = t[p].cl = t[p].cr = t[p].a1 = a[l] >= val;
return ;
}
int mid = l + r >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
update(p);
}
void modify(int p, int l, int r, ci &tt) {
if(l == r) {
t[p].mx = t[p].cl = t[p].cr = t[p].a1 = a[l] >= val;
return ;
}
int mid = l + r >> 1;
if(tt <= mid) modify(lc, l, mid, tt);
else modify(rc, mid + 1, r, tt);
update(p);
}
node query(int p, int l, int r, ci &tl, ci &tr) {
if(tl <= l && r <= tr) return t[p];
int mid = l + r >> 1;
if(tr <= mid) return query(lc, l, mid, tl, tr);
else if(mid < tl) return query(rc, mid + 1, r, tl, tr);
else return merge(query(lc, l, mid, tl, tr), query(rc, mid + 1, r, tl, tr));
}
} rt[T + 1];
void init() {
int res = build(1, n);
t2.build(1, 1, n);
rep(i, 1, T) {
rt[i].val = i;
rt[i].build(1, 1, n);
}
rep(i, 1, n) if(a[i] > T) pos[++ tot] = i;
}
void modify(int i) {
int nv = lcp(s[i], s[i + 1]), x;
if(a[i] > T && nv <= T) {
x = lower_bound(pos + 1, pos + tot + 1, i) - pos;
rep(j, x, tot - 1) pos[j] = pos[j + 1];
tot -- ;
}
if(a[i] <= T && nv > T) {
x = lower_bound(pos + 1, pos + tot + 1, i) - pos;
tot ++ ;
per(j, tot, x + 1) pos[j] = pos[j - 1];
pos[x] = i;
}
swap(a[i], nv);
rep(v, 1, min(T, max(a[i], nv))) rt[v].modify(1, 1, n, i);
}
ll solve(int l, int r) {
ll res = 0;
int top;
rep(i, l, r) R[i] = r + 1;
stk[top = 1] = l;
rep(i, l + 1, r) {
for(; top && a[stk[top]] > a[i]; -- top)
R[stk[top]] = i;
stk[++ top] = i;
}
rep(i, l, r) L[i] = l - 1;
stk[top = 1] = r;
per(i, r - 1, l) {
for(; top && a[stk[top]] > a[i]; -- top)
L[stk[top]] = i;
stk[++ top] = i;
}
rep(i, l, r) chmax(res, 1ll * a[i] * (R[i] - L[i]));
return res;
}
ll query(int l, int r) {
ll res = 0, x;
rep(v, 1, T) {
x = rt[v].query(1, 1, n, l, r - 1).mx;
if(x) chmax(res, 1ll * v * (x + 1));
else break;
}
for(rei L = lower_bound(pos + 1, pos + tot + 1, l) - pos, R; L <= tot && pos[L] < r; L = R + 1) {
for(R = L; R < tot && pos[R + 1] < r && pos[R] + 1 == pos[R + 1]; ++ R) ;
chmax(res, solve(pos[L], pos[R]));
}
chmax(res, t2.query(1, 1, n, l, r));
return res;
}
signed main() {
int opt, l, r;
ios :: sync_with_stdio(0);
cin.tie(nullptr);
cin >> n >> q;
rep(i, 1, n) cin >> s[i];
rep(i, 1, n - 1) a[i] = lcp(s[i], s[i + 1]);
init();
for(; q; -- q) {
cin >> opt >> l;
if(opt == 1) {
cin >> r;
if(l == r) printf("%d\n", s[l].size());
else printf("%lld\n", query(l, r));
} else {
cin >> s[l];
t2.modify(1, 1, n, l);
if(l > 1) modify(l - 1);
if(l < n) modify(l);
}
}
return 0;
}