class Trie {
private:
struct Node {
int end, cnt;
int nxt[62];
};
vector<Node> trie;
public:
Trie() {
init();
}
void init() {
trie.assign(1, Node());
}
void insert(const string &str) {//插入字符串
update(str, 0, 0, 1);
}
void erase(const string &str) {//删除字符串
update(str, 0, 0, -1);
}
int count(const string &str) {//统计trie中字符串s的个数
return query(str, 0, 0, 1);
}
int find(const string &str) {//统计trie中以字符串s为前缀的个数
return query(str, 0, 0, 0);
}
private:
void update(const string &str, int cur, int idx, int f) {
trie[cur].cnt += (cur != 0 ? f : 0);
if(idx == str.size()) {
trie[cur].end += 1;
return;
}
int x = trans(str[idx]);
if(!trie[cur].nxt[x]) {
trie[cur].nxt[x] = trie.size();
trie.push_back(Node());
}
update(str, trie[cur].nxt[x], idx + 1, f);
}
int query(const string &str, int cur, int idx, int f) {
if(idx == str.size()) {
return f ? trie[cur].end : trie[cur].cnt;
}
int x = trans(str[idx]);
if(!trie[cur].nxt[x]) {
return 0;
}
return query(str, trie[cur].nxt[x], idx + 1, f);
}
int trans(char c) {//字符映射函数
return c - (isdigit(c) ? -4 : (isupper(c) ? 39 : 97));
}
};