P8306 【模板】字典树

P8306 【模板】字典树题解

字典树是个啥东西?

字典树,英文名 trie 。顾名思义,就是一个像字典一样的树。
字典树,查找一个字符串是否在 "字典" 中出现过,是一种以空间换时间的算法 (所以这一道题给了1 GB 的内存限制)

用一个例子简单理解一下字典树

就本题而言,需要我们查找有多少个 \(t_i\)\(s_j\) 的前缀,举个例子,当输入的字符串为

abc acc abd da

此时,可以构建出以下的一棵树,并且以每个字符出现顺序编号

那么这个时候我们来思考,如何判断一个字符串是否是另一个字符串的前缀呢?又有多少个字符串的前缀是它呢?
这个时候我们想到,可以在每一个节点上使用 \(cnt_i\) 来记录有多少个字符串曾经 "走到" 过编号为 \(i\) 的节点,现在这棵树就变成了这样(注:左边数字为 \(cnt_i\) ,右边数字为其编号)

由此,假设要询问多少个字符串前缀是 \(ab\) 便是对应着 \(cnt_2\) 的数值,则为 2

现在开始实现代码

1.为字符编号

int getID(char ch) {
	if (islower(ch))return ch - 'a';//小写
	else if (isupper(ch))return ch - 'A' + 26;//大写
	else return ch - '0' + 52;//数字
}

2.插入新的字符串

void insert(string &s) {
	int x = 0;  //开始以根节点出发
	for (int i = 0; i < (int)s.size(); i++) {
		int j = getID(s[i]);  //获取将字符转化为数字之后的编号
		if (!tre[x][j])tre[x][j] = ++idx;//如果未被插入,插入并标记编号
		x = tre[x][j];//跳到下一个节点
		cnt[x]++;//记录"来过"
	}
}

3.查找是否是某一个字符串前缀

int find(string &s) {
	int x = 0;
	for (int i = 0; i < (int)s.size(); i++) {
		int j = getID(s[i]);
		if (!tre[x][j])return 0;//如果没有节点相同,说明没有任何一个前缀是它
		x = tre[x][j];
	}
	return cnt[x];//返回对应的次数
}

好了,把以上内容整合起来, \(AC\) 代码附上

#include<bits/stdc++.h>
using namespace std;
const int N = 3000005;
int tre[N][70];
int T;
int n, q;
string s;
int cnt[N];
int idx;
void write(long long k) {
	if (k < 0) {
		putchar('-');
		k = -k;
	}
	char st[21];
	int top = 0;
	do {
		st[++top] = (k % 10) | 48, k /= 10;
	} while (k);
	while (top)putchar(st[top--]);
}
int getID(char ch) {
	if (islower(ch))return ch - 'a';
	else if (isupper(ch))return ch - 'A' + 26;
	else return ch - '0' + 52;
}
void insert(string &s) {
	int x = 0;
	for (int i = 0; i < (int)s.size(); i++) {
		int j = getID(s[i]);
		if (!tre[x][j])tre[x][j] = ++idx;
		x = tre[x][j];
		cnt[x]++;
	}
}
int find(string &s) {
	int x = 0;
	for (int i = 0; i < (int)s.size(); i++) {
		int j = getID(s[i]);
		if (!tre[x][j])return 0;
		x = tre[x][j];
	}
	return cnt[x];
}
int main() {
	cin >> T;
	while (T--) {
		cin >> n >> q;
		for (int i = 1; i <= n; i++) {
			cin >> s;
			insert(s);
		}
		for (int i = 1; i <= q; i++) {
			cin >> s;
			write(find(s));
			printf("\n");
		}
		for (int i = 0; i <= idx; i++)
			for (int j = 0; j <= 70; j++)
				tre[i][j] = 0;
		for (int i = 0; i <= idx; i++)
			cnt[i] = 0;
		idx = 0;
	}
	return 0;
}

等一等!这题没完!要卡常,千万不要 \(memset()\) 啊!!

posted @ 2025-06-08 22:43  ppi_SAMA  阅读(27)  评论(0)    收藏  举报