题解 CF862F Mahmoud and Ehab and the final stage

首先令 \(a_i=\operatorname{lcp}(s_i,s_{i+1})\),原问题转化成求 \([l,r]\)\(\min\limits_{L\le i< R}\{s_i\}\times(R-L+1)\) 的最大值,单点修改。对于 \(L=R\) 的部分就是求最大长度,显然线段树可以解决。剩下的就是求 \(\min\limits_{L\le i\le R}\{s_i\}\times(R-L+2)\) 的最大值,\(l\le L\le R<r\)

注意到一个很重要的性质:字符串长度总和不超过 \(10^5\),即 \(\sum a_i\le 10^5\)。所以考虑根号分治。设定阈值 \(T\),把答案按照区间最小值的大小分成两类。

  1. 区间最小值 \(\le T\)。假设已经知道最小值为 \(x\),我们只关心某个数是否大于等于 \(x\)。那么把大于等于 \(x\) 的数标为 \(1\),剩下的标为 \(0\),只需要求出区间 \([l,r-1]\) 内的最长的只包含 \(1\) 的连续段。用 \(T\) 棵线段树维护即可。

  2. 区间最小值 \(>T\)。这样的位置最多 \(\dfrac{10^5}{T}\) 个。假设其中存在一个连续区间 \([p,q]\) 并且是 \([L,R]\) 的子区间,对于其中的每一个 \(i\),用单调栈求出 \(lp_i,rp_i\),分别表示 \(i\) 往左 / 往右第一个小于 \(a_i\) 的位置。那么以 \(a_i\) 为最小值的区间的答案就是 \(a_i\times(rp_i-lp_i)\)

总时间复杂度 \(O(q(\dfrac{\sum|s_i|}{T}+T\log n))\),取 \(T=\sqrt{\dfrac{\sum|s_i|}{\log n}}\) 时最优,为 \(O(q\sqrt{(\sum|s_i|)\log n})\)

#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <string>

using namespace std;

typedef long long ll;
typedef pair <int, int> pii;
#define ci const int
inline int max(ci &x, ci &y) { return x > y ? x : y; }
inline int min(ci &x, ci &y) { return x < y ? x : y; }
inline void swap(int &x, int &y) { x ^= y ^= x ^= y; }
inline void chmax(ll &x, ll y) { if(x < y) x = y; }
inline void chmax(int &x, int y) { if(x < y) x = y; }
#define rei register int
#define rep(i, l, r) for(rei i = l, i##end = r; i <= i##end; ++ i)
#define per(i, r, l) for(rei i = r, i##end = l; i >= i##end; -- i)
const int N = 1e5 + 15, T = 30;
string s[N];
int lch[N * 2], rch[N * 2], sgtt;
int pos[N], L[N], R[N], tot;
int stk[N];
int a[N], n, q;

inline int lcp(string &a, string &b) {
	int len = 0, n = min(a.size(), b.size());
	for(; len < n && a[len] == b[len]; ++ len) ;
	return len;
}

struct node {
	int mx, cl, cr;
	bool a1;
	node(int _mx = 0, int _cl = 0, int _cr = 0, bool _a1 = 0)
		: mx(_mx), cl(_cl), cr(_cr), a1(_a1) { }
} ;

node merge(const node &x, const node &y) {
	return node(max(max(x.mx, y.mx), x.cr + y.cl),
		x.cl + (x.a1 ? y.cl : 0), y.cr + (y.a1 ? x.cr : 0), x.a1 && y.a1);
}

int build(int l, int r) {
	int p = ++ sgtt;
	if(l == r) return p;
	int mid = l + r >> 1;
	lch[p] = build(l, mid);
	rch[p] = build(mid + 1, r);
	return p;
}

#define lc lch[p]
#define rc rch[p]

struct SSGT {
	int t[N * 2];

	inline void update(int p) {
		t[p] = max(t[lc], t[rc]);
	}

	void build(int p, int l, int r) {
		if(l == r) {
			t[p] = s[l].size();
			return ;
		}
		int mid = l + r >> 1;
		build(lc, l, mid);
		build(rc, mid + 1, r);
		update(p);
	}

	void modify(int p, int l, int r, ci &tt) {
		if(l == r) {
			t[p] = s[l].size();
			return ;
		}
		int mid = l + r >> 1;
		if(tt <= mid) modify(lc, l, mid, tt);
		else modify(rc, mid + 1, r, tt);
		update(p);
	}

	int query(int p, int l, int r, ci &tl, ci &tr) {
		if(tl <= l && r <= tr) return t[p];
		int mid = l + r >> 1, res = 0;
		if(tl <= mid) chmax(res, query(lc, l, mid, tl, tr));
		if(mid < tr) chmax(res, query(rc, mid + 1, r, tl, tr));
		return res;
	}

} t2;

struct SGT {
	int val;
	node t[N * 2];

	inline void update(int p) {
		t[p] = merge(t[lc], t[rc]);
	}

	void build(int p, int l, int r) {
		if(l == r) {
			t[p].mx = t[p].cl = t[p].cr = t[p].a1 = a[l] >= val;
			return ;
		}
		int mid = l + r >> 1;
		build(lc, l, mid);
		build(rc, mid + 1, r);
		update(p);
	}

	void modify(int p, int l, int r, ci &tt) {
		if(l == r) {
			t[p].mx = t[p].cl = t[p].cr = t[p].a1 = a[l] >= val;
			return ;
		}
		int mid = l + r >> 1;
		if(tt <= mid) modify(lc, l, mid, tt);
		else modify(rc, mid + 1, r, tt);
		update(p);
	}

	node query(int p, int l, int r, ci &tl, ci &tr) {
		if(tl <= l && r <= tr) return t[p];
		int mid = l + r >> 1;
		if(tr <= mid) return query(lc, l, mid, tl, tr);
		else if(mid < tl) return query(rc, mid + 1, r, tl, tr);
		else return merge(query(lc, l, mid, tl, tr), query(rc, mid + 1, r, tl, tr));
	}

} rt[T + 1];

void init() {
	int res = build(1, n);
	t2.build(1, 1, n);
	rep(i, 1, T) {
		rt[i].val = i;
		rt[i].build(1, 1, n);
	}
	rep(i, 1, n) if(a[i] > T) pos[++ tot] = i;
}

void modify(int i) {
	int nv = lcp(s[i], s[i + 1]), x;
	if(a[i] > T && nv <= T) {
		x = lower_bound(pos + 1, pos + tot + 1, i) - pos;
		rep(j, x, tot - 1) pos[j] = pos[j + 1];
		tot -- ;
	}
	if(a[i] <= T && nv > T) {
		x = lower_bound(pos + 1, pos + tot + 1, i) - pos;
		tot ++ ;
		per(j, tot, x + 1) pos[j] = pos[j - 1];
		pos[x] = i;
	}
	swap(a[i], nv);
	rep(v, 1, min(T, max(a[i], nv))) rt[v].modify(1, 1, n, i);
}

ll solve(int l, int r) {
	ll res = 0;
	int top;
	rep(i, l, r) R[i] = r + 1;
	stk[top = 1] = l;
	rep(i, l + 1, r) {
		for(; top && a[stk[top]] > a[i]; -- top)
			R[stk[top]] = i;
		stk[++ top] = i;
	}
	rep(i, l, r) L[i] = l - 1;
	stk[top = 1] = r;
	per(i, r - 1, l) {
		for(; top && a[stk[top]] > a[i]; -- top)
			L[stk[top]] = i;
		stk[++ top] = i;
	}
	rep(i, l, r) chmax(res, 1ll * a[i] * (R[i] - L[i]));
	return res;
}

ll query(int l, int r) {
	ll res = 0, x;
	rep(v, 1, T) {
		x = rt[v].query(1, 1, n, l, r - 1).mx;
		if(x) chmax(res, 1ll * v * (x + 1));
		else break;
	}
	for(rei L = lower_bound(pos + 1, pos + tot + 1, l) - pos, R; L <= tot && pos[L] < r; L = R + 1) {
		for(R = L; R < tot && pos[R + 1] < r && pos[R] + 1 == pos[R + 1]; ++ R) ;
		chmax(res, solve(pos[L], pos[R]));
	}
	chmax(res, t2.query(1, 1, n, l, r));
	return res;
}

signed main() {
	int opt, l, r;
	ios :: sync_with_stdio(0);
	cin.tie(nullptr);
	cin >> n >> q;
	rep(i, 1, n) cin >> s[i];
	rep(i, 1, n - 1) a[i] = lcp(s[i], s[i + 1]);
	init();
	for(; q; -- q) {
		cin >> opt >> l;
		if(opt == 1) {
			cin >> r;
			if(l == r) printf("%d\n", s[l].size());
			else printf("%lld\n", query(l, r));
		} else {
			cin >> s[l];
			t2.modify(1, 1, n, l);
			if(l > 1) modify(l - 1);
			if(l < n) modify(l);
		}
	}
	return 0;
}
posted @ 2021-09-15 19:27  EverlastingEternity  阅读(42)  评论(0编辑  收藏  举报