题解:qoj15502 字符串问题

听说是 runs 板子,但是不会也能做。

题意:给出一个字符串和权值序列 \(c\),定义子串 \([l,r]\) 的权值为 \(c_k\),这里 \(k\) 是子串的最小整周期的出现次数。现在问对于 \(i\in [1,n]\),以 \(i\) 为右端点的子串的权值之和。\(n\le 10^6\)

做法:

首先求刚好出现 \(k\) 次的有点麻烦,我们考虑把权值稍微改改,使得变成出现 \(d\mid k\) 次就贡献 \(f_d\),这样就不要求是最小整周期了而是周期即可。也就是我们需要求一个 \(f\) 使得 \(c_n = \sum\limits_{d\mid n}f_d\),直接类似筛法的方式算算贡献即可,复杂度 \(O(n\log n)\)

然后考虑怎么算答案,对于出现一次的周期不好处理,我们认为每个子串都带了一个自身周期,所以可以直接加,考虑出现两次及以上的周期即可。

我们枚举周期长度,直接考虑每个端开始的话那还是 \(O(n)\) 的枚举,我们其实很想 \(O(\frac n i)\) 地枚举。其实是可行的,我们只要用一个经典 trick。我们对于 $i,2i\cdots $ 的地方撒一个关键点,那么如果一个周期出现了至少两次,那么他们一定覆盖了至少两个关键点。我们枚举是哪两个相邻关键点,并且求出以他们开头和结尾的公共 lcp/lcs 长度,也就是我们找到了如下图的一个相等情况:

我们发现这两段是相等的,且都以 \(i\) 为周期,记第一段中为 \([l,r]\)。注意,这里要求红色段长度一定要小于 \(i\),否则会被多次统计算多。

那么我们发现,对于 \([l+i-1,l+2i-2]\) 这一部分贡献为 \(f_1\)\([l+2i-1,l+3i-2]\) 这一部分为 \(f_1+f_2\),以此类推。发现其实等于若干个区间加。因为每个位置在 \(i\) 长的时候最多只会被 cover,所以加的区间个数是 \(O(n\log n)\) 的,可以用差分处理。

现在的问题是怎么找任意两个位置开始/结尾的 lcp 和 lcs,这个可以用你喜欢的字符串数据结构维护,我喜欢 SAM 就写了 SAM,结果被卡了点空间,把欧拉序 lca 换成 dfs 序 lca 就可以了。

代码:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6 + 5, mod = 998244353;
int n, a[maxn], v;
char c[maxn];
struct node {
	int lnk, len, nxt[26];
} ;
struct SAM {
	node tr[maxn * 2];
	int tot = 1, lst = 1, pos[maxn];
	void add(int c, int id) {
		int cur = ++tot, p = lst;
		pos[id] = cur;
		tr[cur].len = tr[p].len + 1;
		while(p && !tr[p].nxt[c])
			tr[p].nxt[c] = cur, p = tr[p].lnk;
		if(!p)
			tr[cur].lnk = 1;
		else {
			int q = tr[p].nxt[c];
			if(tr[p].len + 1 == tr[q].len)
				tr[cur].lnk = q;
			else {
				int clone = ++tot; tr[clone] = tr[q];
				tr[clone].len = tr[p].len + 1;
				while(p && tr[p].nxt[c] == q)
					tr[p].nxt[c] = clone, p = tr[p].lnk;
				tr[q].lnk = tr[cur].lnk = clone;
			}
		}
		lst = cur;	
	}	
	vector<int> e[maxn * 2];
	void prepare() {
		for (int i = 1; i <= tot; i++)
			e[tr[i].lnk].push_back(i);
	//	cout << tr[2].lnk << endl;
	}
	vector<int> st[23];
	int cnt, dep[maxn * 2], lg[maxn * 4], dfn[maxn * 2];
	int get_min(int u, int v) {
		return (dfn[u] < dfn[v] ? u : v);
	}
	void dfs(int u) {
		dfn[u] = ++cnt; 
		for (int i = 0; i < e[u].size(); i++) {
			int v = e[u][i];
			//cout << u << " " << v << " " << endl;
			st[0].push_back(u);
			dfs(v);
		}
	}
	void preparest() {
		lg[0] = -1;
		cnt = st[0].size();
		for (int i = 1; i <= cnt; i++)
			lg[i] = lg[i >> 1] + 1;
		for (int j = 1; (1 << j) <= cnt; j++) {
			st[j].push_back(0);
			for (int i = 1; i + (1 << j) - 1 <= cnt; i++)
				st[j].push_back(get_min(st[j - 1][i], st[j - 1][i + (1 << j - 1)]));
		}
	}
	void build() {
		prepare();
		st[0].push_back(0); st[0].push_back(0);
		dfs(1);
		preparest();
	}
	int lca(int x, int y) {
		x = pos[x], y = pos[y];
		x = dfn[x], y = dfn[y];
		if(x > y)
			swap(x, y);
		x++;
		int k = lg[y - x + 1];
	//	cout << x << " " << y << " " << get_min(st[k][x], st[k][y - (1 << k) + 1]) << endl;
		return tr[get_min(st[k][x], st[k][y - (1 << k) + 1])].len;
	}
} PS, SS;
void preparea() {
	v = a[1];
	for (int i = 1; i <= n; i++) {
		for (int j = 2 * i; j <= n; j += i)
			a[j] = (a[j] - a[i] + mod) % mod;
	}
	a[1] = 0;
}
void preparec() {
	for (int i = 1; i <= n; i++)
		SS.add(c[i] - 'a', i);
	for (int i = n; i >= 1; i--)
		PS.add(c[i] - 'a', i);
	SS.build(), PS.build();
}
int cf[maxn];
signed main() {
	cin >> n;
	for (int i = 1; i <= n; i++)
		cin >> c[i];
	for (int i = 1; i <= n; i++)
		cin >> a[i];
	preparea();
	preparec();
	for (int i = 1; i <= n; i++)
		a[i] = (a[i - 1] + a[i]) % mod;
	for (int i = 1; 2 * i <= n; i++) {
		for (int x = i, y = 2 * i; y <= n; x += i, y += i) {
			if(c[x] != c[y])
				continue;
			int pre = PS.lca(x, y), suf = SS.lca(x, y);
			swap(pre, suf);
		//	cout << x << " " << y << " " << pre << " " << suf << endl;
			int len = pre + suf + y - x - 1;
			if(len >= 2 * i && pre <= i) {
		//		cout << x << " " << y << endl;
				int lx = x - pre + 1, rx = y + suf - 1, nw = lx + i - 1;
				for (int cnt = 1; nw <= rx; nw += i, cnt++) {
					cf[nw] = (cf[nw] + a[cnt]) % mod;
					cf[min(rx + 1, nw + i)] = (cf[min(rx + 1, nw + i)] - a[cnt] + mod) % mod;
				}
			}
		}
	}
	for (int i = 1; i <= n; i++) {
		cf[i] = (cf[i] + cf[i - 1]) % mod;
		cout << (cf[i] + 1ll * v * i % mod) % mod << " ";
	}
	return 0;
}
/*
8
babaaabb
0 1 1 0 0 0 0 0

*/
posted @ 2025-12-19 21:55  LUlululu1616  阅读(3)  评论(0)    收藏  举报