ID3决策树算法

基本概念:

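信息熵是对信息不确定程度的一种度量。假定一个系统 $s$ 具有概率分布 $p=\{p_i\}$($0 \le p_i \le 1$,$\sum_{i=1}^{n} p_i = 1$,$i=1,2,\dots,n$),则系统 $s$ 的信息熵定义为 $H(s) = -\sum_{i=1}^{n} p_i \log_2 p_i$(约定 $0 \log_2 0 = 0$)。假设 $X$ 是一个集合,如果存在一组集合 $A_1, A_2, \dots, A_n$ 满足下列条件:(1) $A_i \neq \emptyset$;(2) 当 $i \neq j$ 时 $A_i \cap A_j = \emptyset$;(3) $\bigcup_{i=1}^{n} A_i = X$,则称 $A_1, \dots, A_n$ 是集合 $X$ 的一个划分。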

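ID3算法使用信息熵作为度量标准:选择熵最小的属性作为分类属性,递归地完成决策树的构造。其中,属性 $A$ 的熵定义为按该属性的各个取值 $v$ 划分数据集 $D$ 后,各子集熵的加权和,即 $E(A) = \sum_{v} \frac{|D_v|}{|D|} H(D_v)$;使 $E(A)$ 最小等价于使信息增益 $\mathrm{Gain}(A) = H(D) - E(A)$ 最大(下面代码中的 Grain 函数计算的就是信息增益)。在生成树的过程中,每个节点只有一个属性值(加权熵相同的属性值看成一个属性值)。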

树的递归结束条件是以下任意一条成立:划分后的集合中所有元组都属于同一类;树已达到所要求的深度;某个类的元组个数达到了一定的阈值。

这是我写的一个ID3算法的例子:

#include<stdio.h>
#include<math.h>
#include<string.h>
#include<stdlib.h>
/* Class labels, stored in column KIND of each training row. */
enum { SHORT = 0, MEDIUM = 1, TALL = 2 };
/* Gender codes, stored in column GENDER (spellings kept as-is; presumably
   "male"/"female"). */
enum { MAIL = 0, FEMAIL = 1 };
/* Column indices into the rows of trainingData / testData. */
enum { GENDER = 0, HEIGHT = 1, KIND = 2 };
/* Value stored in Node.attribute to mark a leaf. */
enum { LEAF = -1 };
/* A decision-tree node: an internal node tests one attribute and has one
   child slot per attribute value; a leaf carries a class label. */
typedef struct TNODE Node;
struct TNODE
{
	int attribute;           /* attribute column tested here, or LEAF for a leaf */
	int arriv_value;         /* attribute value on the edge from the parent (-1 at the root) */
	struct TNODE *child[50]; /* child[v]: subtree for attribute value v, NULL if none */
	int childCount;          /* number of child slots used (0 for a leaf) */
	int classification;      /* class label at a leaf; -1 on internal nodes */
};
/* Number of distinct values per attribute column: GENDER has 2 codes,
   HEIGHT has 6 height buckets. */
int attriCnt[10]={2,6};
/* Number of class labels (SHORT, MEDIUM, TALL).  Entropy() only counts
   labels below classCnt, so this must cover every value stored in the KIND
   column; with 2, TALL rows were ignored and all entropies were wrong. */
int classCnt=3;
/* Encoded training rows: column GENDER, HEIGHT (bucket) and KIND (class). */
int trainingData[100][30];
/* Encoded test rows: columns GENDER and HEIGHT; the class is unknown. */
int testData[100][3];
/*
 * Shannon entropy, in bits, of the class labels of a subset of training rows.
 *
 * indexArray: row indices into trainingData; the KIND column of each is read.
 * len:        number of indices in indexArray.
 * Returns -sum_c p_c * log2(p_c), where p_c is the fraction of the len rows
 * whose label is c, for labels c in [0, classCnt).  Returns 0 when len <= 0.
 */
double Entropy(int *indexArray,int len)
{
	int cnt[10]={0};
	/* Never index past cnt[], whatever classCnt is set to. */
	const int classes=classCnt<10?classCnt:10;
	double sum=0;
	int i;

	if(len<=0)return 0;
	/* A single pass tallies every label (labels outside the range are skipped). */
	for(i=0;i<len;i++)
	{
		int label=trainingData[indexArray[i]][KIND];
		if(label>=0&&label<classes)cnt[label]++;
	}
	for(i=0;i<classes;i++)
	{
		double p;
		if(cnt[i]==0)continue;   /* 0 * log2(0) is taken as 0 */
		p=cnt[i]*1.0/len;
		sum-=p*log2(p);
	}
	return sum;
}
/*
 * Information gain of splitting a subset of training rows on one attribute
 * (the name is kept for existing callers).
 *
 * indexArray: row indices into trainingData forming the current subset.
 * attri:      attribute column to split on; its values are 0..attriCnt[attri]-1.
 * len:        number of indices in indexArray.
 * Returns Entropy(subset) - sum_v (|subset_v| / len) * Entropy(subset_v),
 * or 0 when len <= 0.
 */
double Grain(int *indexArray,int attri,int len)
{
	/* One attribute value can be shared by every row in the subset, so the
	   scratch buffer holds as many indices as trainingData has rows (assumes
	   indexArray lists distinct rows, as buildTree passes them). */
	int subIndexArray[sizeof trainingData/sizeof trainingData[0]];
	double hd;
	double weighted=0;
	int v,j;

	if(len<=0)return 0;   /* avoids 0/0 in the weights below */
	hd=Entropy(indexArray,len);
	for(v=0;v<attriCnt[attri];v++)
	{
		int sublen=0;
		/* Collect the rows whose attribute takes value v. */
		for(j=0;j<len;j++)
		{
			if(trainingData[indexArray[j]][attri]==v)
			{
				subIndexArray[sublen++]=indexArray[j];
			}
		}
		weighted+=sublen*1.0/len*Entropy(subIndexArray,sublen);
	}
	return hd-weighted;
}
/*
 * Majority class among the given training rows.
 *
 * chooseIndex: row indices into trainingData.
 * lines:       number of indices.
 * Returns the label in 0..2 with the highest count.  Ties go to the smaller
 * label, and an empty set yields 0.
 */
int toClass(int *chooseIndex,int lines)
{
	int votes[3]={0,0,0};
	int best=0;
	int k;

	for(k=0;k<lines;k++)
		votes[trainingData[chooseIndex[k]][KIND]]++;
	/* Only a strictly larger count displaces the current winner. */
	for(k=1;k<3;k++)
		if(votes[k]>votes[best])best=k;
	return best;
}
/*
 * Tests whether all selected training rows carry the same class label.
 *
 * chooseIndex: row indices into trainingData.
 * len:         number of indices.
 * Returns 1 when every KIND value matches (trivially for len <= 1), else 0.
 */
int check_attribute(int *chooseIndex,int len)
{
    int k=1;
    /* Advance while each row matches its predecessor. */
    while(k<len&&trainingData[chooseIndex[k]][KIND]==trainingData[chooseIndex[k-1]][KIND])
        k++;
    return k>=len;
}
/* Allocates a node with the given fields and every child slot set to NULL.
   Prints a message and exits the program if malloc fails. */
static Node *newTreeNode(int attribute,int arriv_value,int childCount,int classification)
{
	Node *no=malloc(sizeof *no);
	size_t i;

	if(no==NULL)
	{
		fprintf(stderr,"buildTree: out of memory\n");
		exit(EXIT_FAILURE);
	}
	no->attribute=attribute;
	no->arriv_value=arriv_value;
	no->childCount=childCount;
	no->classification=classification;
	for(i=0;i<sizeof no->child/sizeof no->child[0];i++)no->child[i]=NULL;
	return no;
}

/*
 * Recursively builds an ID3 decision tree over a subset of training rows.
 *
 * chooseIndex:      row indices into trainingData for this subtree.
 * lines:            number of indices in chooseIndex.
 * remain_attribute: attribute columns not yet used on the path from the root.
 * attriNumber:      number of entries in remain_attribute.
 * arriv_value:      attribute value on the edge from the parent (-1 at the root).
 * Returns a heap-allocated subtree, or NULL when lines == 0.
 */
Node *buildTree(int *chooseIndex,int lines,int *remain_attribute,int attriNumber,int arriv_value)
{
	int subRemain_attribute[10];
	const int maxRemain=(int)(sizeof subRemain_attribute/sizeof subRemain_attribute[0]);
	int choose_attribute;
	int k=0;
	int i,j;
	Node *no;

	if(lines==0)return NULL;
	/* Stop when all rows share one class, or when no attribute is left to
	   split on (without the second check choose_attribute would be read
	   uninitialised).  Either way the leaf takes the rows' majority class. */
	if(check_attribute(chooseIndex,lines)||attriNumber<=0)
	{
		return newTreeNode(LEAF,arriv_value,0,toClass(chooseIndex,lines));
	}

	/* Choose the remaining attribute with the largest information gain; a
	   single candidate is taken without evaluating it. */
	choose_attribute=remain_attribute[0];
	if(attriNumber>1)
	{
		double maxgrain=-1;
		for(i=0;i<attriNumber;i++)
		{
			double gain=Grain(chooseIndex,remain_attribute[i],lines);
			if(gain>maxgrain)
			{
				maxgrain=gain;
				choose_attribute=remain_attribute[i];
			}
		}
	}

	/* Attributes still available to the children. */
	for(i=0;i<attriNumber&&k<maxRemain;i++)
	{
		if(remain_attribute[i]!=choose_attribute)
		{
			subRemain_attribute[k++]=remain_attribute[i];
		}
	}

	no=newTreeNode(choose_attribute,arriv_value,attriCnt[choose_attribute],-1);
	for(i=0;i<attriCnt[choose_attribute];i++)
	{
		int subChooseIndex[sizeof trainingData/sizeof trainingData[0]];
		int subLines=0;
		for(j=0;j<lines;j++)
		{
			if(trainingData[chooseIndex[j]][choose_attribute]==i)
			{
				subChooseIndex[subLines++]=chooseIndex[j];
			}
		}
		/* No training row has this value: label the branch with the parent's
		   majority class so such inputs can still be classified. */
		if(subLines==0)
			no->child[i]=newTreeNode(LEAF,i,0,toClass(chooseIndex,lines));
		else
			no->child[i]=buildTree(subChooseIndex,subLines,subRemain_attribute,k,i);
	}
	return no;
}
/* Writes two tab characters per indentation level to stdout; prints
   nothing when deep <= 0. */
void blank(int deep)
{
    int level;
    for(level=deep;level>0;level--)
        fputs("\t\t",stdout);
}
/*
 * Prints the subtree rooted at root in pre-order.  Each node is printed as
 * an indented record: the attribute it tests (or "leaf arrived"), the value
 * on the edge from its parent, its child count and its class label, then a
 * separator line.  Children are printed one level deeper.
 *
 * root: subtree to print; NULL prints nothing.
 * deep: indentation level of root, rendered by blank().
 */
void Triverse(Node *root,int deep)
{
    const int slots=(int)(sizeof root->child/sizeof root->child[0]);
    int i,shown;

    if(root==NULL)return;
    blank(deep);
    switch (root->attribute)
    {
        case GENDER:printf("classification:gender\n");break;
        case HEIGHT:printf("classification:height\n");break;
        case LEAF:printf("leaf arrived\n");break;
        default:printf("%d\n",root->attribute);break;
    }
    blank(deep);
    printf("arriv_value: %d\n",root->arriv_value);blank(deep);
    printf("childCount: %d\n",root->childCount);blank(deep);
    printf("classification: %d\n",root->classification);blank(deep);
    printf("------------------------------------------\n");
    /* Never walk past the fixed-size child[] array. */
    shown=root->childCount<slots?root->childCount:slots;
    for(i=0;i<shown;i++)
    {
        Triverse(root->child[i],deep+1);
    }
}
/*
 * Classifies row lineNumber of testData by walking the tree from root, then
 * prints the predicted class (Short / Medium / Tall).
 *
 * Prints "classify failed!" in three cases: the walk reaches a NULL subtree,
 * the row's attribute value has no matching child slot, or the leaf's
 * classification is not 0, 1 or 2.
 */
void Classify(int lineNumber,Node *root)
{
    const int columns=(int)(sizeof testData[0]/sizeof testData[0][0]);
    const int slots=(int)(sizeof root->child/sizeof root->child[0]);

    /* Leaves are recognised by attribute == LEAF.  Testing child[0] == NULL
       would misfire on internal nodes whose value-0 branch is empty. */
    while(root!=NULL&&root->attribute!=LEAF)
    {
        int attribute=root->attribute;
        int childIndex;
        if(attribute<0||attribute>=columns)
        {
            root=NULL;
            break;
        }
        childIndex=testData[lineNumber][attribute];
        if(childIndex<0||childIndex>=root->childCount||childIndex>=slots)
        {
            root=NULL;
            break;
        }
        root=root->child[childIndex];
    }
    if(root==NULL)
    {
        printf("classify failed!\n");
        return;
    }
    switch (root->classification)
    {
        case 0:printf("the training data belongs to Short\n");break;
        case 1:printf("the training data belongs to Medium\n");break;
        case 2:printf("the training data belongs to Tall\n");break;
        default: printf("classify failed!\n");break;
    }
}
int main()
{
    FILE *fp;
	fp=fopen("./data.txt","r");
	if(fp==NULL)
	{
		printf("Can not open file\n");
		return 0;
	}
	char name[10],kind[10],gender[10];
    double height;
	int lines=0;
	while(fscanf(fp,"%s",name)!=EOF)
	{
        fscanf(fp,"%s",gender);
		if(!strcmp(gender,"F"))
		{
			trainingData[lines][0]=FEMAIL;
		}
		else trainingData[lines][0]=MAIL;
		fscanf(fp,"%lf",&height);
        if(height>=1.6&&height<1.7)
        {
            trainingData[lines][1]=0;
        }
        else if(height>=1.7&&height<1.8)
        {
            trainingData[lines][1]=1;
        }
        else if(height>=1.8&&height<1.9)
        {
            trainingData[lines][1]=2;
        }
        else if(height>=1.9&&height<2.0)
        {
            trainingData[lines][1]=3;
        }
        else if(height>=2.0&&height<2.1)
        {
            trainingData[lines][1]=4;
        }
        else if(height>=2.1&&height<=2.2)
        {
            trainingData[lines][1]=5;
        }
        fscanf(fp,"%s",kind);
        if(!strcmp(kind,"Short"))
        {
            trainingData[lines][2]=SHORT;
        }
        else if(!strcmp(kind,"Medium"))
        {
            trainingData[lines][2]=MEDIUM;
        }
        else trainingData[lines][2]=TALL;
        lines++;
	}
    //printf("lines: %d\n",lines);
    int i,j;
   /* for(i=0;i<lines;i++)
    {
        for(j=0;j<=2;j++)printf("%d\t",trainingData[i][j]);
        printf("\n");
    }*/
    fclose(fp);fp=NULL;
    int index[100],remain_attribute[100];
    for(i=0;i<lines;i++)index[i]=i;
    for(i=0;i<2;i++)remain_attribute[i]=i;
    printf("print the decision tree:\n");
    Node *root=buildTree(index,lines,remain_attribute,2,-1);
    Triverse(root,0);
    printf("The training data is:\n");
    fp=fopen("./testData.txt","r");
    if(fp==NULL)
    {
        printf("Can not open the file!\n");
        return 0;
    }
    int testlines=0;
    while(fscanf(fp,"%s",name)!=EOF)
	{
        printf("%s\t",name);
        fscanf(fp,"%s",gender);
        printf("%s\t",gender);
		if(!strcmp(gender,"F"))
		{
			testData[testlines][0]=FEMAIL;
		}
		else testData[testlines][0]=MAIL;
		fscanf(fp,"%lf",&height);
        printf("%lf\n",height);
        if(height>=1.6&&height<1.7)
        {
            testData[testlines][1]=0;
        }
        else if(height>=1.7&&height<1.8)
        {
            testData[testlines][1]=1;
        }
        else if(height>=1.8&&height<1.9)
        {
            testData[testlines][1]=2;
        }
        else if(height>=1.9&&height<2.0)
        {
            testData[testlines][1]=3;
        }
        else if(height>=2.0&&height<2.1)
        {
            testData[testlines][1]=4;
        }
        else if(height>=2.1&&height<=2.2)
        {
            testData[testlines][1]=5;
        }
        testlines++;
	}
    /*for(i=0;i<testlines;i++)
    {
        for(j=0;j<2;j++)
        {
            printf("%d\t",testData[i][j]);
        }printf("\n");
    }*/
    Classify(0,root);
}

 

posted @ 2013-12-30 23:39  湖心北斗  阅读(253)  评论(0编辑  收藏  举报