Loading

CF917E 做题记录

让我深感畏惧的题目。题目链接:https://codeforces.com/problemset/problem/917/E

考虑将 \(s\) 在路径 \(u \to v\) 上的出现分为两类:完全落在 \(u\) 到 LCA、或 LCA 到 \(v\) 某一条链上的,以及跨越 LCA 的。

对于完全落在一条链上的出现,将所有串的正串和反串一起建 AC 自动机;对树上每个点,求出从根走到该点所得字符串在 AC 自动机上的匹配状态(对应自动机上的一个结点)。

对于每个询问,相当于统计链上满足如下条件的点的个数:该点的匹配状态结点在 fail 树上位于 \(s\)(或其反串)终止结点的子树内。这是一个二维数点问题(代码中用树上 DFS 加树状数组、按根到点的前缀作差来实现)。

关键问题在于求解跨越 lca 的,考虑将每个询问挂在对应的 \(s\) 处。

记询问 \((u, v)\) 的 LCA 为 \(f\)。为统计跨越 \(f\) 的出现,对 \(f \to v\) 需要求出最长的串,使其既是 \(f \to v\) 所构成字符串的前缀,又是 \(s\) 的后缀。对 \(f \to u\) 同理,只是把 \(s\) 换成其反串,即求最长的 \(s\) 的前缀,使其反串是 \(f \to u\) 的前缀。求出这两个长度后,就可以在 \(s\) 的正串与反串的失配树上做二维数点来求解答案。

求解这个最长长度依然有难度,考虑使用后缀数组。

这里讲述一下这个我没见过的新做法:

  • \(s\) 建立后缀数组;

  • 将 \(f \to u\) 组成的字符串在 \(s\) 的后缀数组上二分,比较时用树上倍增配合哈希求 lcp;

  • 找到 lcp 最大时对应的 \(s\) 的后缀 \(s[p : n]\),此时我们要找的答案后缀长度一定不超过此时 lcp 长度;

  • 所以要找的答案后缀是 \(s[p : n]\) 的一个后缀。又因为其不超过 lcp 长度,所以也是 \(s[p : n]\) 的一个前缀;

  • 因此答案后缀是 \(s[p : n]\) 的一个 border(或 \(s[p : n]\) 本身)。在失配树上从 \(p\) 出发倍增,找到长度不超过 lcp 的最长者,即为答案后缀。

感觉这个东西很厉害,我们得到了一种新的可以实现 Z 函数的功能且不依赖均摊的算法。

\(n, m, \sum |s|\) 同级,时间复杂度 \(\mathcal O(n\log ^ 2n)\)

点击查看代码
#include <bits/stdc++.h>
#define ll int
#define LL long long
#define pb push_back
#define pir pair <ll, ll>
#define mkp make_pair
#define fi first
#define se second
#define i128 __int128
using namespace std;
template <class T>
// Reads the next integer from stdin into x.
// Any non-digit characters before the first digit are skipped; a '-' among them
// makes the result negative. The character right after the number is consumed.
// If the input ends before any digit is found, x is set to 0 (previously this
// looped forever).
void rd(T &x) {
    // Keep getchar()'s result as int: isdigit() is only defined for values
    // representable as unsigned char or EOF, and EOF must stay distinguishable.
    int ch = getchar();
    bool neg = false;
    while(ch != EOF && !isdigit(ch)) {
        if(ch == '-') neg = true;
        ch = getchar();
    }
    x = 0;
    if(ch == EOF) return;  // input exhausted: no number to read
    while(ch != EOF && isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if(neg) x = -x;
}
// Global limits and moduli.
// maxn sizes every per-node / per-position array; it presumably assumes
// n, m, q and the string lengths are at most about 1e5 (pw[] is only filled up
// to 1e5 in main) -- confirm against the problem limits.
// iv = mod - mod / 2 = (mod + 1) / 2, the inverse of 2 modulo mod.
// NOTE(review): M, inf, iv and INF are not referenced anywhere in this file.
const ll maxn = 2e5 + 10, M = 1e6 + 10, inf = 1e9 + 5, mod = 1e9 + 7, iv = mod - mod / 2;
const LL INF = 1e18 + 5;
// Binary exponentiation: returns a^b modulo p.
// With the default arguments this is a^(mod - 2) mod mod, i.e. the modular
// inverse of a for the prime mod (when a is not a multiple of mod).
ll power(ll a, ll b = mod - 2, ll p = mod) {
    ll result = 1;
    for(; b; b >>= 1) {
        // Multiply in the current square whenever the low bit of b is set.
        if(b & 1) result = 1ll * result * a % p;
        a = 1ll * a * a % p;
    }
    return result;
}
// x <- (x + y) mod `mod`, assuming both operands already lie in [0, mod).
template <class T1, class T2>
void add(T1 &x, const T2 y) {
    const auto sum = x + y;
    if(sum >= mod) x = sum - mod;
    else x = sum;
}
// x <- (x - y) mod `mod`, assuming both operands already lie in [0, mod).
template <class T1, class T2>
void sub(T1 &x, const T2 y) {
    if(x < y) x = x + mod - y;
    else x = x - y;
}
// Returns (x + y) mod `mod` for operands in [0, mod).
template <class T1, class T2>
ll pls(const T1 x, const T2 y) {
    const auto sum = x + y;
    if(sum >= mod) return sum - mod;
    return sum;
}
// Returns (x - y) mod `mod` for operands in [0, mod).
template <class T1, class T2>
ll mus(const T1 x, const T2 y) {
    if(x < y) return x + mod - y;
    return x - y;
}
// a <- min(a, b): a is overwritten with b unless a is strictly smaller.
template <class T1, class T2>
void chkmin(T1 &a, const T2 b) {
    if(!(a < b)) a = b;
}
// a <- max(a, b): a is overwritten with b only when a is strictly smaller.
template <class T1, class T2>
void chkmax(T1 &a, const T2 b) {
    if(a < b) a = b;
}

// n: tree vertices, m: strings, q: queries (read in main).
// sa / rk / cnt / id / oldrk: suffix-array work arrays used by Build();
// oldrk is indexed up to 2 * l - 1 during prefix doubling.
// len[i]: length of string s[i].
ll n, m, q, sa[maxn], rk[maxn], cnt[maxn], id[maxn], oldrk[maxn], len[maxn];
// ans[i]: answer of query i. a[j] / b[j]: string positions computed per query
// in main. nxt: border jump table built by Build(). d / dep: binary-lifting
// ancestors and depth (root 1 has depth 1). l: length of the string currently
// being processed. f[i]: lca of query i's endpoints.
ll ans[maxn], a[maxn], b[maxn], nxt[maxn][20], d[maxn][20], dep[maxn], l, f[maxn];
// s[i]: the i-th string, stored 1-indexed. str: scratch buffer for the current
// string. w[v]: label of the edge from v's parent to v.
char *s[maxn], str[maxn], w[maxn];
// to[u]: adjacency list of (neighbour, edge label).
vector <pair <ll, char> > to[maxn];
// vec[p]: queries on string p as ((u, v), query index).
vector <pair <pir, ll> > vec[maxn];

// Fenwick tree over positions 1..ti; ti doubles as the global DFS clock.
ll dat[maxn], ti;
// Adds v at position x.
void upd(ll x, ll v) {
	while(x <= ti) {
		dat[x] += v;
		x += x & -x;
	}
}
// Returns the sum over positions 1..x.
ll ask(ll x) {
	ll total = 0;
	while(x) {
		total += dat[x];
		x -= x & -x;
	}
	return total;
}

const ll mod1 = 998244353, mod2 = 1e9 + 7;
struct Data { ll x, y; Data(ll a = 0, ll b = 0) { x = a, y = b; } } h[maxn], pw[maxn], ts[maxn];
const Data operator + (const Data A, const Data B) {
 return Data(A.x + B.x >= mod1? A.x + B.x - mod1 : A.x + B.x,
 A.y + B.y >= mod2? A.y + B.y - mod2 : A.y + B.y); }
const Data operator - (const Data A, const Data B) {
 return Data(A.x < B.x? A.x - B.x + mod1 : A.x - B.x,
 A.y < B.y? A.y - B.y + mod2 : A.y - B.y); }
const Data operator * (const Data A, const Data B) {
 return Data(1ll * A.x * B.x %mod1, 1ll * A.y * B.y %mod2); }
const bool operator == (const Data A, const Data B) { return A.x == B.x && A.y == B.y; }
const bool operator != (const Data A, const Data B) { return A.x != B.x || A.y != B.y; }

// Roots the tree: fills depth, the binary-lifting table d[u][*], the label w[]
// of each child's parent edge, and ts[] (hash of the root-to-node label string).
void dfs(ll u, ll fa = 0) {
	d[u][0] = fa;
	dep[u] = dep[fa] + 1;
	for(ll k = 1; k < 20; k++) d[u][k] = d[d[u][k - 1]][k - 1];
	for(const auto &edge: to[u]) {
		const ll child = edge.fi;
		if(child == fa) continue;
		w[child] = edge.se;
		ts[child] = ts[u] * Data(131, 13331) + Data(w[child], w[child]);
		dfs(child, u);
	}
}
// Lowest common ancestor of u and v via the binary-lifting table d[][].
ll lca(ll u, ll v) {
	if(dep[u] < dep[v]) swap(u, v);
	// Lift the deeper node to v's depth, one set bit of the gap at a time.
	const ll gap = dep[u] - dep[v];
	for(ll k = 0; k < 20; k++)
		if((gap >> k) & 1) u = d[u][k];
	if(u == v) return u;
	// Lift both while their ancestors differ; they stop just below the LCA.
	for(ll k = 19; k >= 0; k--)
		if(d[u][k] != d[v][k]) {
			u = d[u][k];
			v = d[v][k];
		}
	return d[u][0];
}
// Returns the k-th ancestor of u (k is decomposed into powers of two).
ll jump(ll u, ll k) {
	for(ll bit = 0; bit < 20; bit++)
		if((k >> bit) & 1) u = d[u][bit];
	return u;
}

// Aho-Corasick automaton over every pattern s[i] and its reverse. It counts the
// occurrences of a query's pattern that lie entirely on one of the chains
// u -> f or f -> v of the query path (f = lca(u, v)).
namespace ACAM {
	// q, h, t: BFS queue and its head/tail; tot: number of nodes (1 is the root);
	// trie: transitions (completed into a full goto function by build());
	// fail: suffix links; ed[i] / ed[i + m]: node where s[i] / reversed s[i] ends.
	ll q[maxn], h, t, tot = 1, trie[maxn][26], fail[maxn], ed[maxn];
	// dfn / out: Euler-tour interval of every node in the fail tree g.
	ll dfn[maxn], out[maxn]; vector <ll> g[maxn];
	void dfs(ll u) {
		dfn[u] = ++ti;
		for(ll v: g[u]) dfs(v);
		out[u] = ti;
	}
	// BFS from the root: computes fail links, fills missing transitions, builds
	// the fail tree g and numbers it with dfs(1). The numbering continues from
	// the global clock ti, so ti should be 0 when this runs.
	void build() {
		h = 1, t = 0;
		for(ll i = 0; i < 26; i++)
			if(!trie[1][i]) trie[1][i] = 1;
			else fail[q[++t] = trie[1][i]] = 1;
		while(h <= t) {
			ll u = q[h++]; g[fail[u]].pb(u);
			for(ll i = 0; i < 26; i++)
				if(!trie[u][i]) trie[u][i] = trie[fail[u]][i];
				else fail[q[++t] = trie[u][i]] = trie[fail[u]][i];
		}
		dfs(1);
	}
	// sec[u] holds counting requests evaluated at tree node u: each adds
	// w * (number of nodes on the root -> u path whose automaton state lies in
	// the fail subtree of ed[k]) to ans[id].
	struct info { ll k, w, id; }; vector <info> sec[maxn];
	// DFS over the input tree. p is the automaton state after reading the edge
	// labels from the root down to u; the Fenwick tree (indexed by fail-tree dfn)
	// contains exactly the states of the nodes on the current root -> u path.
	void tdfs(ll u, ll p = 1, ll fa = 0) {
		upd(dfn[p], 1);
		for(info t: sec[u]) ans[t.id] += t.w * (ask(out[ed[t.k]]) - ask(dfn[ed[t.k]] - 1));
		for(auto e: to[u]) {
			ll v = e.fi;
			if(v ^ fa) {
				tdfs(v, trie[p][e.se - 'a'], u);
			}
		}
		upd(dfn[p], -1);
	}
	void solve() {
		for(ll i = 1; i <= m; i++) {
			// Insert s[i] forwards ...
			ll p = 1;
			for(ll j = 1; j <= len[i]; j++) {
				ll ch = s[i][j] - 'a';
				if(!trie[p][ch]) trie[p][ch] = ++tot;
				p = trie[p][ch];
			} ed[i] = p;
			// ... and backwards.
			p = 1;
			for(ll j = len[i]; j; j--) {
				ll ch = s[i][j] - 'a';
				if(!trie[p][ch]) trie[p][ch] = ++tot;
				p = trie[p][ch];
			} ed[i + m] = p;
			for(auto t: vec[i]) {
				ll u = t.fi.fi, v = t.fi.se, j = t.se;
				// An occurrence inside u -> f is an occurrence of the reversed
				// pattern on the downward path f -> u, ending at a path node of
				// depth >= dep[f] + len[i]. Count those nodes as
				// count(root -> u) - count(root -> x), where x is u's ancestor at
				// depth dep[f] + len[i] - 1.
				if(dep[u] - dep[f[j]] >= len[i]) {
					ll x = jump(u, dep[u] - dep[f[j]] - len[i] + 1);
					sec[u].pb((info) {i + m, 1, j});
					sec[x].pb((info) {i + m, -1, j});
				}
				// Same for the downward chain f -> v with the pattern itself.
				if(dep[v] - dep[f[j]] >= len[i]) {
					ll x = jump(v, dep[v] - dep[f[j]] - len[i] + 1);
					sec[v].pb((info) {i, 1, j});
					sec[x].pb((info) {i, -1, j});
				}
			}
		}
		// build() already numbers the fail tree via dfs(1). Do not call dfs(1)
		// again here: a second pass renumbers every node into (tot, 2 * tot] and
		// leaves ti = 2 * tot, so Fenwick indices can overrun dat[maxn].
		build();
		tdfs(1);
	}
}

// Builds all per-string structures for the global string str[1..l]
// (the caller sets str[l + 1] = 0):
//   * sa[1..l] / rk[1..l]: suffix array and 1-based ranks, by prefix doubling
//     with radix sort; rank 0 is reserved for "past the end of the string";
//   * h[0..l]: prefix hashes of str (h[0] is never written and stays 0);
//   * nxt[1..l+1][0..19]: nxt[i][0] is the start of the longest proper border
//     of the suffix str[i..l] (l + 1 = empty border), plus binary-lifting levels;
//   * g: the border tree on positions 1..l+1 rooted at l + 1, with edges
//     nxt[i][0] -> i. The caller must clear g[1..l+1] beforehand.
void Build(vector <ll> *g) {
	// Round 0: bucket positions by their first character.
	for(ll i = 0; i < 26; i++) cnt[i] = 0;
	for(ll i = 1; i <= l; i++) ++cnt[str[i] - 'a'];
	for(ll i = 1; i < 26; i++) cnt[i] += cnt[i - 1];
	for(ll i = l; i; i--) sa[cnt[str[i] - 'a']--] = i;
	ll p = 0;
	for(ll i = 1; i <= l; i++)
		if(str[sa[i]] == str[sa[i - 1]]) rk[sa[i]] = p;
		else rk[sa[i]] = ++p;
	// The doubling rounds read oldrk[x] for x up to 2l - 1, and entries past l
	// must read as 0 (empty suffix, below every real rank). Build runs once per
	// string, so wipe ranks that an earlier, longer string left behind there;
	// otherwise two different suffixes can be merged into one rank and the
	// final suffix array comes out mis-ordered.
	for(ll i = l + 1; i <= 2 * l; i++) oldrk[i] = 0;
	for(ll w = 1; w < l; w <<= 1) {
		// id: positions ordered by the rank of their second half (x + w);
		// positions whose second half is empty come first.
		p = 0;
		for(ll i = 0; i < w; i++) id[++p] = l - i;
		for(ll i = 1; i <= l; i++)
			if(sa[i] > w) id[++p] = sa[i] - w;
		// Stable counting sort of id by the rank of the first half.
		for(ll i = 1; i <= l; i++) cnt[i] = 0;
		for(ll i = 1; i <= l; i++) ++cnt[oldrk[i] = rk[i]];
		for(ll i = 1; i <= l; i++) cnt[i] += cnt[i - 1];
		for(ll i = l; i; i--) sa[cnt[rk[id[i]]]--] = id[i];
		// Re-rank: equal (first half, second half) pairs share a rank.
		p = 0;
		for(ll i = 1; i <= l; i++)
			if(oldrk[sa[i]] == oldrk[sa[i - 1]] &&
			 oldrk[sa[i] + w] == oldrk[sa[i - 1] + w]) rk[sa[i]] = p;
			else rk[sa[i]] = ++p;
	}
	// Prefix hashes: h[i] = hash of str[1..i].
	for(ll i = 1; i <= l; i++) h[i] = h[i - 1] * Data(131, 13331) + Data(str[i], str[i]);
	// Border tree, computed right to left (KMP on the reversed string).
	// Position l + 1 is the root (empty border) and points to itself at every level.
	for(ll i = 0; i < 20; i++) nxt[l + 1][i] = nxt[l][i] = l + 1;
	g[nxt[l][0]].pb(l);
	for(ll i = l - 1, j = l + 1; i; i--) {
		// Invariant: j == nxt[i + 1][0]. Shrink that border until it can be
		// extended on the left by str[i].
		while(j <= l && str[j - 1] != str[i]) j = nxt[j][0];
		if(str[j - 1] == str[i]) nxt[i][0] = --j;
		else nxt[i][0] = l + 1;
		g[nxt[i][0]].pb(i);
		for(ll k = 1; k < 20; k++) nxt[i][k] = nxt[nxt[i][k - 1]][k - 1];
	}
}
// Border trees of the current string: to2 for s itself, to1 for reversed s.
vector <ll> to1[maxn], to2[maxn];

// Hash of the label string on the downward path f -> u (f an ancestor of u).
Data Tree(ll u, ll f) {
	const ll span = dep[u] - dep[f];
	return ts[u] - ts[f] * pw[span];
}
// Hash of str[l..r], taken from the prefix hashes h[].
Data Seq(ll l, ll r) {
	const Data before = h[l - 1] * pw[r - l + 1];
	return h[r] - before;
}

// slen[k]: lcp length between the current path string and the suffix starting
// at sa[k], as recorded by the most recent find() call that probed rank k.
ll slen[maxn];
// Binary search over the suffix array of str[1..l] for the path string
// T = edge labels on the downward path f -> u (f must be an ancestor of u).
// Returns the smallest rank k such that T <= suffix sa[k] (T being a prefix of
// the suffix counts as <=), or l + 1 if there is none. Every probed rank stores
// its lcp with T in slen[]. The last probe that moved hi was at the returned lo,
// and the last probe that moved lo was at lo - 1, so slen[lo] (when lo <= l) and
// slen[lo - 1] (when lo > 1) are filled by this call. String equality is
// decided by the double hash.
ll find(ll u, ll f) {
	ll lo = 1, hi = l;
	while(lo <= hi) {
		ll mid = lo + hi >> 1, x = u, p = sa[mid];
		// Whole of T fits into the suffix and matches: T is its prefix.
		if(l - p + 1 >= dep[u] - dep[f] && Tree(u, f) == Seq(p, p + dep[u] - dep[f] - 1)) {
			slen[mid] = dep[u] - dep[f];
			hi = mid - 1; continue;
		}
		// Otherwise lift x from u to the shallowest node whose prefix f -> x is
		// either longer than the suffix or differs from it; the edge into x then
		// carries T's first character past the common prefix.
		for(ll i = 19; ~i; i--)
			if(dep[d[x][i]] >= dep[f] && (dep[d[x][i]] - dep[f] > l - p + 1 ||
			 Tree(d[x][i], f) != Seq(p, p + dep[d[x][i]] - dep[f] - 1)))
				x = d[x][i];
		slen[mid] = dep[x] - dep[f] - 1;
		// Compare the first differing characters. When the suffix is exhausted
		// this reads str[l + 1], which main() sets to 0, so T compares greater.
		if(w[x] > str[p + dep[x] - dep[f] - 1]) lo = mid + 1;
		else hi = mid - 1;
	} return lo;
}

// Starting from the suffix str[u..l], returns the start position of the longest
// border of str[u..l] (str[u..l] itself included) whose length is at most lim;
// l + 1 stands for the empty string.
ll Get(ll u, ll lim) {
	if(l - u + 1 <= lim) return u;
	for(ll k = 19; k >= 0; k--) {
		const ll jumped = nxt[u][k];
		// Keep climbing while the border is still longer than lim.
		if(l - jumped + 1 > lim) u = jumped;
	}
	// u is the shortest border that is still too long; its parent fits.
	return nxt[u][0];
}

// Euler-tour interval [dfn[u], out[u]] of every node of the to1 tree,
// numbered with the global clock ti.
ll dfn[maxn], out[maxn];
void dfs1(ll u) {
	dfn[u] = ++ti;
	for(const ll child: to1[u]) dfs1(child);
	out[u] = ti;
}
// mdf[u]: to1 nodes whose subtrees are active while u is on the current to2 path.
vector <ll> mdf[maxn];
// qur[u]: (to1 node, query id) pairs evaluated when the walk reaches u.
vector <pir> qur[maxn];
// Walks the to2 tree from u. The Fenwick tree works as range-add / point-query
// over to1's dfn order, so ask(dfn[x]) counts active subtrees that contain x.
void dfs2(ll u) {
	auto toggle = [&](ll sign) {
		for(ll v: mdf[u]) upd(dfn[v], sign), upd(out[v] + 1, -sign);
	};
	toggle(1);
	for(const pir &qry: qur[u]) ans[qry.se] += ask(dfn[qry.fi]);
	for(ll child: to2[u]) dfs2(child);
	toggle(-1);
}

// Entry point. Reads n, m, q; then n - 1 labelled tree edges "u v c"; then the
// m strings; then q queries "u v p" about string p on the path u -> v.
// Prints one answer per query.
int main() {
	// pw[i] = (131, 13331)^i, filled for i <= 1e5.
	rd(n), rd(m), rd(q); pw[0] = Data(1, 1);
	for(ll i = 1; i <= 1e5; i++) pw[i] = pw[i - 1] * Data(131, 13331);
	// Undirected labelled edges; dfs(1) roots the tree at node 1.
	for(ll i = 1; i < n; i++) {
		ll u, v; rd(u), rd(v);
		char ch[2]; scanf("%s", ch);
		to[u].pb(mkp(v, ch[0])), to[v].pb(mkp(u, ch[0]));
	} dfs(1);
	// Store each string 1-indexed in s[i][1..len[i]] (no terminator is stored).
	for(ll i = 1; i <= m; i++) {
		scanf("%s", str + 1); len[i] = strlen(str + 1);
		s[i] = new char[len[i] + 1];
		for(ll j = 1; j <= len[i]; j++) s[i][j] = str[j];
	}
	// Group queries by string and record f[i] = lca(u, v). a[i] / b[i] start at a
	// sentinel meant to exceed every position l + 1 (assumes lengths <= 1e5).
	for(ll i = 1; i <= q; i++) {
		ll u, v, p; rd(u), rd(v), rd(p);
		vec[p].pb(mkp(mkp(u, v), i)), f[i] = lca(u, v);
		a[i] = b[i] = 1e5 + 5;
	}
	// Occurrences lying completely on one of the two chains of each query.
	ACAM::solve();
	// Occurrences crossing the LCA, processed string by string.
	for(ll i = 1; i <= m; i++) {
		if(vec[i].empty()) continue;
		l = len[i];
		// Copy s[i] into str; str[l + 1] = 0 is the terminator find() relies on.
		for(ll j = 1; j <= l; j++) str[j] = s[i][j]; str[l + 1] = 0;
		for(ll j = 1; j <= l + 1; j++) to1[j].clear(), to2[j].clear();
		// Forward string: b[j] = start of the longest suffix of s that is a
		// prefix of the path string f -> v (l + 1 when empty). The best lcp is at
		// one of the two suffix-array neighbours of the search position.
		Build(to2);
		for(auto t: vec[i]) {
			// NOTE(review): u is unused in this loop.
			ll u = t.fi.fi, v = t.fi.se, j = t.se;
			ll pos = find(v, f[j]);
			if(pos <= l) chkmin(b[j], Get(sa[pos], slen[pos]));
			if(pos > 1) chkmin(b[j], Get(sa[pos - 1], slen[pos - 1]));
		}
		// Reversed string: a[j] = start (in reversed s) of the longest suffix of the
		// reversed string that is a prefix of the path string f -> u.
		reverse(str + 1, str + 1 + l);
		Build(to1);
		// NOTE: this loop's i shadows the outer string index i; the outer i is
		// used again by the next loop after this one's scope ends.
		for(ll i = 1; i <= l + 1; i++) qur[i].clear(), mdf[i].clear();
		for(auto t: vec[i]) {
			// NOTE(review): v is unused in this loop.
			ll u = t.fi.fi, v = t.fi.se, j = t.se;
			ll pos = find(u, f[j]);
			if(pos <= l) chkmin(a[j], Get(sa[pos], slen[pos]));
			if(pos > 1) chkmin(a[j], Get(sa[pos - 1], slen[pos - 1]));
			qur[b[j]].pb(mkp(a[j], j));
		}
		// Splitting s after its first i characters (i again shadows the outer i):
		// the tail s[i+1..l] is node i + 1 of to2, and reverse(s[1..i]) is node
		// l - i + 1 of to1.
		for(ll i = 1; i < l; i++) mdf[i + 1].pb(l - i + 1);
		// Number to1 (ti restarts at 0), then walk to2 to answer the queries.
		ti = 0, dfs1(l + 1);
		dfs2(l + 1);
	}
	for(ll i = 1; i <= q; i++) printf("%d\n", ans[i]);
    return 0;
}
posted @ 2026-01-24 16:19  Sktn0089  阅读(3)  评论(0)    收藏  举报