Trie in C++

问题背景:
假设数据库中存有搜索词条及其对应的搜索频度。当用户输入某一串字符时, 需要输出以该输入为前缀(前缀匹配不区分大小写)、搜索频度最高的若干个词条(下面的程序至多输出 10 个)。
下面是基于 Trie 的算法实现。这种实现的好处是查询效率高, 并且支持动态更新(能快速插入新词条或修改已有词条的频度)。
C++ 源代码:
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <queue>
#include <string>
#include <utility>
#include <vector>
using namespace std;
  8 
  9 const int n_ascii = 256;
 10 
 11 struct trie_node {
 12     trie_node* ptrs[n_ascii];
 13     int freq;
 14     trie_node (int f = 0) : freq(f) {
 15         memset(ptrs, 0, sizeof(ptrs));
 16     }
 17 };
 18 
 19 class trie {
 20 private:
 21     trie_node* root;
 22     
 23 public:
 24     trie() : root(new trie_node()) {}
 25     ~trie() { release(root); }
 26     
 27     void insert(const string& word, int freq) {
 28         trie_node* prev = root;
 29         for (auto iter = word.begin(); iter != word.end(); ++iter) {
 30             int ch = *iter;
 31             trie_node* node = prev->ptrs[ch];
 32             if (!node) {
 33                 prev->ptrs[ch] = node = new trie_node(); 
 34             }
 35             prev = node;
 36         }
 37         prev->freq = freq;
 38     }
 39 
 40     typedef pair<trie_node*, string> search_result_type;
 41     inline const vector<search_result_type> search(string prefix, bool case_sensitive = false) {
 42         if (case_sensitive) {
 43             return search_case_sensitive(prefix);    
 44         } else {
 45             return search_case_insensitive(prefix);
 46         }
 47     }
 48     
 49     const vector<search_result_type> search_case_sensitive(string prefix) const {
 50         trie_node* curr = root;
 51         for (auto iter = prefix.begin(); iter != prefix.end(); ++iter) {
 52             int ch = *iter;
 53             if (curr) curr = curr->ptrs[ch];
 54             else break;
 55         }
 56         vector<search_result_type> tmp;
 57         if (curr) tmp.push_back(make_pair(curr, prefix));
 58         return tmp;
 59     }
 60 
 61     inline const vector<search_result_type> search_case_insensitive(string prefix) const {
 62         vector<search_result_type> search_result;
 63         string prefix2(prefix);
 64         aux_search(root, prefix, 0, prefix2, search_result);
 65         return search_result;
 66     }
 67     
 68 private:
 69     trie(const trie& t);
 70     const trie& operator=(const trie& t);
 71 
 72     void aux_search(const trie_node* pnode, const string& word, const int i, 
 73         string& prefix, vector<search_result_type>& search_result) const {
 74         if (i == word.size()) {
 75             search_result.push_back(
 76                 make_pair(const_cast<trie_node*>(pnode), prefix));
 77         } 
 78         else if (pnode) {
 79             prefix[i] = word[i];
 80             aux_search(pnode->ptrs[word[i]], word, i+1, prefix, search_result);
 81             
 82             int ch = word[i];
 83             if (isupper(ch)) ch = tolower(ch);
 84             else if (islower(ch)) ch = toupper(ch);
 85             else ch = 0;
 86             if (ch) {
 87                 prefix[i] = ch;
 88                 aux_search(pnode->ptrs[ch], word, i+1, prefix, search_result);
 89             }
 90         }
 91     }
 92     
 93     void release(trie_node * root) {
 94         if (root) {
 95             for (int i = 0; i < n_ascii; i++) {
 96                 release(root->ptrs[i]);
 97             }
 98             delete root;
 99         }
100     }
101 };
102 
103 
// Return `str` without its leading spaces and tabs.
// An empty string, or one consisting only of spaces/tabs, yields "".
string trim_head(const string& str) {
    const auto pos = str.find_first_not_of(" \t");
    if (pos == string::npos) return string();  // nothing but blanks: trims to empty
    return str.substr(pos);
}
110 
// Return `str` without its trailing spaces and tabs.
// An empty string, or one consisting only of spaces/tabs, yields "".
string trim_tail(const string& str) {
    const auto pos = str.find_last_not_of(" \t");
    if (pos == string::npos) return string();  // nothing but blanks: trims to empty
    return str.substr(0, pos + 1);
}
117 
// One candidate answer: (word, frequency).
using value_type = pair<string, int>;

// "a ranks before b" when a has the higher frequency, ties broken by the
// lexicographically smaller word. As a priority_queue comparator this makes
// top() the *worst* candidate currently held.
struct my_comp {
    bool operator()(const value_type& lhs, const value_type& rhs) const {
        return (lhs.second == rhs.second) ? (lhs.first < rhs.first)
                                          : (lhs.second > rhs.second);
    }
};
my_comp comp_obj;

// Bounded collection of the best candidates found so far, and its capacity.
priority_queue<value_type, vector<value_type>, my_comp> priq;
int max_heap_size = 0;
129 
130 void push_heap(const value_type& v) {
131     if (priq.size() < max_heap_size) {
132         priq.push(v);
133     } else if (comp_obj(v, priq.top())) {
134         priq.pop();
135         priq.push(v);
136     }
137 }
138 
139 bool is_leaf(const trie_node* pnode) {
140     if (pnode) {
141         for (int i = 0; i < n_ascii; i++) {
142             if (pnode->ptrs[i] != 0) return false;
143         }
144         return true;
145     }
146     return false;
147 }
148 
149 bool is_valid_freq(const trie_node* pnode) {
150     if (pnode) {
151         return pnode->freq != 0;
152     }
153     return false;
154 }
155 
156 void traverse(const trie_node* pnode, const string& prefix) {
157     if (pnode) {
158         if (is_valid_freq(pnode) || is_leaf(pnode)) 
159             push_heap(make_pair(prefix, pnode->freq));
160         for (int i = 0; i < n_ascii; i++) {
161             if (pnode->ptrs[i]) {
162                 string tmp_str(prefix);
163                 traverse(pnode->ptrs[i], tmp_str.append(1, (char)i));
164             }
165         }
166     }
167 }
168 
169 void list_top_n(const trie& tr, const string& prefix, int n) {
170     auto results = tr.search_case_insensitive(prefix);
171     max_heap_size = n;
172     for (auto result : results) {
173         traverse(result.first, result.second);
174     }
175     
176     vector<value_type> items;
177     int lim = priq.size();
178     for (int i = 0; i < lim; i++) {
179         items.push_back(priq.top());
180         priq.pop();
181     }
182 
183     if (lim) {
184         for (auto iter = items.rbegin(); iter != items.rend(); ++iter) {
185             cout << iter->first << ' ' << iter->second << endl;
186         }
187     } else {
188         cout << "not found !!!" << endl;
189     }
190 }
191 
192 int main() {
193     trie tr; 
194     string line;
195     while (getline(cin, line)) {
196         string words;
197         auto iter = line.begin();
198         for (; iter != line.end(); ++iter) {
199             if (!isdigit(*iter)) {
200                 words.append(1, (char)(*iter));
201             } else break;
202         }
203         if (isdigit(*iter)) {
204             words = trim_tail(words);
205             tr.insert(words, atoi(&(*iter)));
206         } else break;
207     }
208     
209     string prefix;
210     while (getline(cin, prefix)) {
211         cout << prefix << " : "<< endl;
212         prefix = trim_head(prefix);
213         list_top_n(tr, prefix, 10);
214         cout << endl;
215     }
216     return 0;
217 }

测试:

输入(先是若干行“词条 频度”, 以一个空行结束; 之后每行是一个待查询的前缀):

Baidu   100
Google  100
Google Map  150
Google Play 200
gfsoso  100
google  250
Go 50

G

 

输出:

G : 
google 250
Google Play 200
Google Map 150
Google 100
gfsoso 100
Go 50

 

posted @ 2015-02-19 15:53  william-cheung  阅读(286)  评论(0)  编辑  收藏  举报