[C++]简单即敲即得系统(基于Trie树)

这几日实现了一个基于Trie树的简单即敲即得系统,数据库为MySql,测试数据为380万条记录。输入数据为utf-8格式,在Trie树节点中保存的是wchar类型的字符,因为一个汉字需要2个字节的长度,如果使用char类型就必须将一个汉字拆成两个节点来保存,这样以后要做模糊搜索什么的就比较麻烦。

工程共包含5个代码文件,分别如下:


Char2Wchar.h : 实现char到wchar_t的转换

代码

 

 Trie.h : 定义trie树的结构

代码
#ifndef __TRIE__
#define __TRIE__

#include <string.h>
#include <stdlib.h>

// structure of trie node
struct TrieNode
{
    TrieNode():kidlen(0),buflen(0),kids(NULL),invlen(0),ibflen(0),invList(NULL){}
    
// parameter list
    wchar_t key;
    
int kidlen,buflen;
    TrieNode* kids;
    
int invlen,ibflen;
    unsigned* invList;

    TrieNode* Insert(wchar_t w)
    {
        
for (int i=0;i<kidlen;i++)
            
if(kids[i].key==w)
                
return &kids[i];
        
//check memory
        if (kidlen==buflen)
        {
            buflen=buflen*2+1;
            
int tlen = sizeof(TrieNode)*buflen;
            kids=(TrieNode*)realloc(kids,tlen);
        }
        
//add child node
        new (&kids[kidlen]) TrieNode();
        kids[kidlen].key=w;
        
return &kids[kidlen++];
    }

    
// add record id to inverted list
    void add2Invlist(unsigned id)
    {
        
if (invlen>0)
            
if (invList[invlen-1]==id)return;
        
if (invlen==ibflen)
        {
            ibflen=ibflen*2+1;
            size_t nlen=sizeof(unsigned)*ibflen;
            invList=(unsigned*)realloc(invList,nlen);
        }
        invList[invlen++]=id;
    }

    
// search function
    TrieNode* Search(wchar_t w)
    {
        
for (int i=0;i<kidlen;i++)
            
if(kids[i].key==w)
                
return &kids[i];
        
return NULL;
    }
};


#endif


TypeAheadSearch.h :定义搜索结构

代码
#ifndef H_TYPEAHEADSEARCH
#define H_TYPEAHEADSEARCH

#include <vector>
using namespace std;

class TypeAheadSearch 
{
public:
    
bool createIndex(const char* user, const char* passwd, const char* host, const char* db, const char* table);
    
bool search(const char *query, const int topk, vector<unsigned>& results);
};

#endif

 

TypeAheadSearch.cpp : 实现索引建立及搜索

代码
#include "TypeAheadSearch.h"
#include "Trie.h"
#include "Char2W.h"
#include <Windows.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <mysql.h> 
#include <ctime>
#include <iostream>
using namespace std;

unsigned port=3306;  //server port    
unsigned makeId=1;
MYSQL myCont;
MYSQL_RES *result;
MYSQL_ROW sql_row;
MYSQL_FIELD *fd;
time_t Start,End;
TrieNode firstfloor[65535];  //children list of root

// switch char* to wchar_t*
wchar_t* ConvChar(char*ch)
{
    
int len=strlen(ch);
    wchar_t* temp=(wchar_t*)malloc((len+1)*sizeof(wchar_t));
    SWIchar2wchar((const unsigned char*)ch,temp,len);
    
return temp;
}

// print trie tree
void showIndex(TrieNode* node,char* span)
{
    printf("%sinvlist: ",span);
    
for(int i=0;i<node->invlen;i++)
        printf("%u ",node->invList[i]);
    printf("\n");
    
if (node->kidlen>0)
    {
        
for (int i=0;i<node->kidlen;i++)
        {
            
char sp[50];
            sprintf(sp,"%s  ",span);
            showIndex(&(node->kids[i]),sp);
        }
    }
}

// create index function
void MakeIndex(MYSQL_RES *result)
{
    
int fn=mysql_num_fields(result);
    unsigned count=0;
    TrieNode* tmpnode=NULL;
    wchar_t* word=NULL;
    
char* ch=NULL;
    
int i,j;
    
while(sql_row=mysql_fetch_row(result))
    {
        
for (i=0;i<fn;i++// visit fields of record
        {
            
if ((int)*sql_row[i]==0continue// field is NULL
            ch=strtok(sql_row[i]," ");         // split item by space
            if (ch)
            {
                word=ConvChar(ch);
                tmpnode=&firstfloor[word[0]];
                
for (j=1;j<wcslen(word);j++)
                {
                    tmpnode=tmpnode->Insert(word[j]); // insert word
                }
                tmpnode->add2Invlist(makeId); // add recordid to inverted list
                while(ch=strtok(NULL," "))
                {
                    word=ConvChar(ch);
                    tmpnode=&firstfloor[word[0]];
                    
for (j=1;j<wcslen(word);j++)
                    {
                        tmpnode=tmpnode->Insert(word[j]);
                    }
                    tmpnode->add2Invlist(makeId);
                }
            }
        }
        makeId++;
    }
}

// create index procedure
bool TypeAheadSearch::createIndex(const char* user, const char* passwd, const char* host, const char* db, const char* table)
{
    Start=clock();
    mysql_init(&myCont);
    
if(mysql_real_connect(&myCont,host,user,passwd,db,port,NULL,0)) //connect to mysql
    {
        printf("Connect to DataBase succeed!\n");
        mysql_set_character_set(&myCont,"UTF8");
        
char sql[100];
        sprintf(sql,"select * from %s",table);
        printf("making index..\n");
        memset(firstfloor,0,sizeof(firstfloor));
        
int res = mysql_query(&myCont,sql);
        
if(!res) // query succeed
        {
            result=mysql_use_result(&myCont);
            
if(result)
            {
                MakeIndex(result);
            }
        }
        
else
        {
            printf("Query failed!\n");
            
return false;
        }
        mysql_free_result(result);
        End=clock();
        
double utime=(double)(End-Start)/CLOCKS_PER_SEC;
        printf("makeid = %u\n",makeId);
        printf("makeIndex succeed!\ntime used: %lf seconds\n\n",utime);
    }
    
else 
    {
        printf("Connect to DataBase failed!\n");
        
return false;
    }
    mysql_free_result(result);
    mysql_close(&myCont);
    
return true;    
}

// making search results
void MakeResult(TrieNode*node,vector<unsigned>& results)
{
    
for (int i=0;i<node->kidlen;i++)
        MakeResult(&node->kids[i],results);
    
for (int i=0;i<node->invlen;i++)
        results.push_back(node->invList[i]);
}

// serach procedure
bool TypeAheadSearch::search(const char *querys, const int topk, vector<unsigned>& results)
{
    vector<unsigned> rts,bing;
    Start=clock();
    
char query[65535];
    strcpy(query,querys);
    
char*ch=strtok((char*)query," "); // split query by space
    if (ch)
    {
        wchar_t* word=ConvChar(ch);
        TrieNode* temp=NULL;
        temp=&firstfloor[word[0]];
        
if (temp==NULL) return true;
        
for (int i=1;i<wcslen(word);i++)
        {
            temp=temp->Search(word[i]);
            
if (temp==NULL)
                
return true;
        }
        MakeResult(temp,results);
        sort(results.begin(), results.end());
        vector<unsigned>::iterator iter = unique(results.begin(), results.end());
        results.erase(iter, results.end()); //unique items of results

        
while(ch=strtok(NULL," "))
        {
            wchar_t* word=ConvChar(ch);
            TrieNode* temp=NULL;
            temp=&firstfloor[word[0]];
            
if (temp==NULL) return true;
            
for (int i=1;i<wcslen(word);i++)
            {
                temp=temp->Search(word[i]);
                
if (temp==NULL)
                    
return true;
            }
            MakeResult(temp,rts);
            sort(rts.begin(), rts.end());
            vector<unsigned>::iterator it = unique(rts.begin(), rts.end());
            rts.erase(it, rts.end());
            set_intersection(results.begin(),results.end(),rts.begin(),rts.end(),back_inserter(bing));
            results=bing; // intersection of two key words
            rts.clear();
            bing.clear();
        }
        End=clock();
        
double utime=(double)(End-Start)/CLOCKS_PER_SEC;
        printf("search succeed!\ntime used: %lf seconds\n",utime);
    }
    
return true;
}

 

TestCase.cpp : 测试样例

代码
#include "TypeAheadSearch.h"
#include <iostream>
#include <vector>
using namespace std;

const char user[] = "root";                // username
const char pswd[] = "root";                // password
const char host[] = "localhost";           // or"127.0.0.1"
const char db[]   = "database_name";       // database
const char table[] = "table_name";         // database

void testCase(char* query)
{
   
// create index
    TypeAheadSearch* tas = new TypeAheadSearch();
    
if (!tas->createIndex(user,pswd,host,db,table)) printf("Create Index Error!\n");

    
int topk=100// max number of result
    vector<unsigned int> results;
    
if (tas->search(query,topk,results)) 
    {
        printf("number of results: %d\n",results.size());
        
if (topk>results.size()) 
            topk=results.size();
        
for (int i=0;i<topk;i++)
                printf("%u\n",results[i]);
    }
    
else 
        printf("Search Error!\n");
}

int main()
{
    testCase("querys");
     
return 0;
}

 

测试结果:

 

 

posted @ 2009-12-01 10:28  lovebread  阅读(1582)  评论(0编辑  收藏  举报