ID3决策树算法
基本概念:
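信息熵是对信息不确定程度的一种度量。假定一个系统s具有概率分布p={p_i}(0<=p_i<=1,且 $\sum_{i=1}^{n} p_i = 1$),i=1,2,3,...,n,则系统s的信息熵定义为 $H(s) = -\sum_{i=1}^{n} p_i \log_2 p_i$(约定 $0 \log_2 0 = 0$)。
假设X是一个集合,如果存在一组集合A_1,A_2,A_3,...,A_n,满足下列条件:(1) $A_i \neq \varnothing$;(2) 当 $i \neq j$ 时 $A_i \cap A_j = \varnothing$;(3) $A_1 \cup A_2 \cup \cdots \cup A_n = X$,则称A_1,...,A_n是集合X的一个划分。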
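ID3算法使用信息熵作为度量标准,选择划分后信息熵最小的属性作为分类属性,完成决策树的构造,其中属性A的熵定义为该属性各个属性值的加权熵之和,即 $H(S \mid A) = \sum_{v} \frac{|S_v|}{|S|} H(S_v)$,其中 $S_v$ 是属性A取值为v的元组集合。由于 $H(S)$ 对所有属性都相同,这等价于选择信息增益 $Gain(A) = H(S) - H(S \mid A)$ 最大的属性(下面代码中的Grain函数计算的就是这个增益)。在生成树的过程中,每个节点只有一个属性值(权熵相同的属性值看成一个属性值)。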
树的递归结束条件是:划分后的集合中的元组都属于同一类,或者树已达到所要求的深度,或者某个类的元组个数达到了一定的阈值;此外,当已没有剩余属性可用于划分时也必须停止,此时以集合中的多数类作为叶子节点的类别。
这是我写的一个ID3算法的例子:
#include<stdio.h>
#include<math.h>
#include<string.h>
#include<stdlib.h>
/* Class labels stored in the KIND column of a record. */
enum { SHORT = 0, MEDIUM = 1, TALL = 2 };

/* Codes stored in the GENDER column. MAIL/FEMAIL presumably mean
   male/female; the spelling is kept because the rest of the file uses it. */
enum { MAIL = 0, FEMAIL = 1 };

/* Column indices of an encoded record (trainingData / testData rows). */
enum { GENDER = 0, HEIGHT = 1, KIND = 2 };

/* Value of a node's `attribute` field that marks it as a leaf. */
enum { LEAF = -1 };
/*
 * Decision-tree node. Internal nodes split on `attribute`; leaves have
 * attribute == LEAF and carry the predicted label in `classification`.
 */
typedef struct TNODE Node;

struct TNODE
{
    int attribute;            /* column this node splits on, or LEAF */
    int arriv_value;          /* parent's attribute value leading here (-1 at the root) */
    struct TNODE *child[50];  /* child[v]: subtree for attribute value v; unused slots NULL */
    int childCount;           /* number of child slots in use (0 for a leaf) */
    int classification;       /* class label at a leaf, -1 at internal nodes */
};
/* Number of distinct values of each attribute column:
   column 0 (gender) has 2, column 1 (height bucket) has 6. */
int attriCnt[10]={2,6};
/* Number of class labels in the KIND column: SHORT, MEDIUM and TALL.
   Must cover every label, otherwise Entropy() silently ignores the
   records of the missing classes. */
int classCnt=3;
/* Encoded training records, one row per person; only the first three
   columns (gender, height bucket, class) are used. */
int trainingData[100][30];
/* Encoded test records: gender and height bucket (class unknown). */
int testData[100][3];
/*
 * Shannon entropy, in bits, of the class column (KIND) over the records
 * whose trainingData row indices are indexArray[0..len-1].
 * Only labels in [0, classCnt) are counted; classes with no records are
 * skipped (0 * log 0 is taken as 0), so an empty set yields 0.
 */
double Entropy(int *indexArray /* row indices to examine */, int len /* how many */)
{
    int counts[10] = {0};
    double result = 0;
    int k;

    /* One pass: tally how many selected records carry each class label. */
    for (k = 0; k < len; k++) {
        int label = trainingData[indexArray[k]][KIND];
        if (label >= 0 && label < classCnt)
            counts[label]++;
    }

    /* H = -sum(p * log2 p) over the non-empty classes. */
    for (k = 0; k < classCnt; k++) {
        if (counts[k] == 0)
            continue;
        double p = counts[k] * 1.0 / len;
        result -= p * (log(p) / log(2));
    }
    return result;
}
/*
 * Information gain, in bits, of splitting the records indexArray[0..len-1]
 * (trainingData row indices; callers pass a different subset each time) on
 * attribute column `attri`:
 *     Gain = H(S) - sum over v of |S_v|/|S| * H(S_v)
 * where S_v holds the records whose `attri` value is v, v in [0, attriCnt[attri]).
 * The name is a misspelling of "Gain"; it is kept for existing callers.
 * Returns 0 for an empty record set instead of dividing by zero.
 */
double Grain(int *indexArray,int attri,int len)
{
    /* One slot per trainingData row, since a partition may contain every
       selected record. A fixed 10-slot buffer overflowed as soon as more
       than 10 records shared an attribute value. */
    int subIndexArray[sizeof trainingData / sizeof trainingData[0]];
    double conditional = 0;   /* weighted entropy after the split */
    int value, j;

    if (len <= 0)
        return 0;

    for (value = 0; value < attriCnt[attri]; value++) {
        int sublen = 0;
        /* Collect the records whose attribute equals this value. */
        for (j = 0; j < len; j++) {
            if (trainingData[indexArray[j]][attri] == value)
                subIndexArray[sublen++] = indexArray[j];
        }
        conditional += sublen * 1.0 / len * Entropy(subIndexArray, sublen);
    }
    return Entropy(indexArray, len) - conditional;
}
/*
 * Majority vote: return the class label (0, 1 or 2) that occurs most often
 * among the records chooseIndex[0..lines-1]. Ties go to the smaller label;
 * with no records the result is 0.
 */
int toClass(int *chooseIndex,int lines)
{
    int tally[3] = {0, 0, 0};
    int best = 0;
    int n, label;

    for (n = 0; n < lines; n++)
        tally[trainingData[chooseIndex[n]][KIND]]++;

    /* Strict '>' keeps the earliest label on ties. */
    for (label = 1; label < 3; label++) {
        if (tally[label] > tally[best])
            best = label;
    }
    return best;
}
/*
 * Return 1 if every record in chooseIndex[0..len-1] has the same class
 * (KIND) value, 0 otherwise. Sets of zero or one record count as pure.
 */
int check_attribute(int *chooseIndex,int len)
{
    int first, n;

    if (len <= 1)
        return 1;

    first = trainingData[chooseIndex[0]][KIND];
    for (n = 1; n < len; n++) {
        if (trainingData[chooseIndex[n]][KIND] != first)
            return 0;
    }
    return 1;
}
/* Allocate a node with all child slots NULL; exits if memory runs out. */
static Node *newNode(int attribute, int arriv_value)
{
    Node *no = malloc(sizeof *no);
    size_t i;

    if (no == NULL) {
        fprintf(stderr, "out of memory while building the decision tree\n");
        exit(EXIT_FAILURE);
    }
    no->attribute = attribute;
    no->arriv_value = arriv_value;
    no->childCount = 0;
    no->classification = -1;
    for (i = 0; i < sizeof no->child / sizeof no->child[0]; i++)
        no->child[i] = NULL;
    return no;
}

/* Leaf node labelled with the majority class of chooseIndex[0..lines-1]. */
static Node *newLeaf(int *chooseIndex, int lines, int arriv_value)
{
    Node *leaf = newNode(LEAF, arriv_value);
    leaf->classification = toClass(chooseIndex, lines);
    return leaf;
}

/*
 * Recursively build an ID3 decision tree.
 *   chooseIndex      - trainingData row indices of the records at this node
 *   lines            - number of those records
 *   remain_attribute - attribute columns not yet used on this path
 *   attriNumber      - number of entries in remain_attribute
 *   arriv_value      - value of the parent's split attribute leading here
 *                      (-1 for the root)
 * Returns NULL only when lines == 0.
 *
 * Recursion stops with a majority-class leaf when the records are all of
 * one class or when no attribute is left. Previously the no-attribute case
 * fell through with choose_attribute uninitialised (undefined behaviour).
 * Assumes attriCnt[a] <= 50, the capacity of Node::child.
 */
Node *buildTree(int *chooseIndex,int lines,int *remain_attribute,int attriNumber,int arriv_value)
{
    int subRemain_attribute[sizeof attriCnt / sizeof attriCnt[0]];
    int subChooseIndex[sizeof trainingData / sizeof trainingData[0]];
    int choose_attribute;
    double maxgrain = -1;
    int i, j, k, value;
    Node *no;

    if (lines == 0)
        return NULL;

    if (check_attribute(chooseIndex, lines) || attriNumber <= 0)
        return newLeaf(chooseIndex, lines, arriv_value);

    /* Pick the remaining attribute with the largest information gain;
       initialised so it is always defined even if no gain compares greater. */
    choose_attribute = remain_attribute[0];
    for (i = 0; i < attriNumber; i++) {
        double gain = Grain(chooseIndex, remain_attribute[i], lines);
        if (gain > maxgrain) {
            maxgrain = gain;
            choose_attribute = remain_attribute[i];
        }
    }

    /* Attributes still available below this node. */
    k = 0;
    for (i = 0; i < attriNumber; i++) {
        if (remain_attribute[i] != choose_attribute)
            subRemain_attribute[k++] = remain_attribute[i];
    }

    no = newNode(choose_attribute, arriv_value);
    no->childCount = attriCnt[choose_attribute];
    for (value = 0; value < no->childCount; value++) {
        int subLines = 0;
        for (j = 0; j < lines; j++) {
            if (trainingData[chooseIndex[j]][choose_attribute] == value)
                subChooseIndex[subLines++] = chooseIndex[j];
        }
        /* Standard ID3: a value with no training records gets a leaf that
           predicts this node's majority class, so unseen values can still
           be classified instead of hitting a NULL child. */
        no->child[value] = (subLines == 0)
            ? newLeaf(chooseIndex, lines, value)
            : buildTree(subChooseIndex, subLines, subRemain_attribute, k, value);
    }
    return no;
}
/* Write two tab characters per indentation level; nothing for deep <= 0. */
void blank(int deep)
{
    while (deep-- > 0)
        fputs("\t\t", stdout);
}
/*
 * Print the subtree rooted at `root` in pre-order. Each node prints:
 *   - its split attribute (or "leaf arrived");
 *   - arriv_value, childCount and classification;
 *   - a separator line.
 * Its children then follow one level deeper. Every line is indented with
 * blank(deep). A NULL subtree prints nothing.
 */
void Triverse(Node *root,int deep)
{
    int i;

    if (root == NULL)
        return;

    blank(deep);
    switch (root->attribute) {
    case GENDER: printf(" classification:gender\n"); break;
    case HEIGHT: printf("classification:height\n"); break;
    case LEAF:   printf("leaf arrived\n"); break;
    default:     printf("%d\n", root->attribute); break;
    }
    blank(deep);
    printf("arriv_value: %d\n", root->arriv_value);
    blank(deep);
    printf("childCount: %d\n", root->childCount);
    blank(deep);
    printf("classification: %d\n", root->classification);
    blank(deep);
    printf("------------------------------------------\n");

    for (i = 0; i < root->childCount; i++)
        Triverse(root->child[i], deep + 1);
}
/*
 * Predict the class of test record `lineNumber` (a row of testData) by
 * walking the tree from `root`, and print the result.
 *
 * A node counts as a leaf only when its attribute is LEAF. Testing
 * child[0] == NULL misread internal nodes whose value-0 branch had no
 * training records as leaves with classification -1.
 *
 * Prints "classify failed!" when:
 *   - a NULL subtree is reached;
 *   - the record's value has no branch at some node;
 *   - the leaf label is not SHORT/MEDIUM/TALL.
 */
void Classify(int lineNumber,Node *root)
{
    Node *node = root;

    while (node != NULL && node->attribute != LEAF) {
        int value = testData[lineNumber][node->attribute];
        /* A value outside this node's branches cannot be routed anywhere. */
        node = (value >= 0 && value < node->childCount) ? node->child[value] : NULL;
    }

    if (node == NULL) {
        printf("classify failed!\n");
        return;
    }

    switch (node->classification) {
    case SHORT:  printf("the test data belongs to Short\n");  break;
    case MEDIUM: printf("the test data belongs to Medium\n"); break;
    case TALL:   printf("the test data belongs to Tall\n");   break;
    default:     printf("classify failed!\n");                break;
    }
}
int main()
{
FILE *fp;
fp=fopen("./data.txt","r");
if(fp==NULL)
{
printf("Can not open file\n");
return 0;
}
char name[10],kind[10],gender[10];
double height;
int lines=0;
while(fscanf(fp,"%s",name)!=EOF)
{
fscanf(fp,"%s",gender);
if(!strcmp(gender,"F"))
{
trainingData[lines][0]=FEMAIL;
}
else trainingData[lines][0]=MAIL;
fscanf(fp,"%lf",&height);
if(height>=1.6&&height<1.7)
{
trainingData[lines][1]=0;
}
else if(height>=1.7&&height<1.8)
{
trainingData[lines][1]=1;
}
else if(height>=1.8&&height<1.9)
{
trainingData[lines][1]=2;
}
else if(height>=1.9&&height<2.0)
{
trainingData[lines][1]=3;
}
else if(height>=2.0&&height<2.1)
{
trainingData[lines][1]=4;
}
else if(height>=2.1&&height<=2.2)
{
trainingData[lines][1]=5;
}
fscanf(fp,"%s",kind);
if(!strcmp(kind,"Short"))
{
trainingData[lines][2]=SHORT;
}
else if(!strcmp(kind,"Medium"))
{
trainingData[lines][2]=MEDIUM;
}
else trainingData[lines][2]=TALL;
lines++;
}
//printf("lines: %d\n",lines);
int i,j;
/* for(i=0;i<lines;i++)
{
for(j=0;j<=2;j++)printf("%d\t",trainingData[i][j]);
printf("\n");
}*/
fclose(fp);fp=NULL;
int index[100],remain_attribute[100];
for(i=0;i<lines;i++)index[i]=i;
for(i=0;i<2;i++)remain_attribute[i]=i;
printf("print the decision tree:\n");
Node *root=buildTree(index,lines,remain_attribute,2,-1);
Triverse(root,0);
printf("The training data is:\n");
fp=fopen("./testData.txt","r");
if(fp==NULL)
{
printf("Can not open the file!\n");
return 0;
}
int testlines=0;
while(fscanf(fp,"%s",name)!=EOF)
{
printf("%s\t",name);
fscanf(fp,"%s",gender);
printf("%s\t",gender);
if(!strcmp(gender,"F"))
{
testData[testlines][0]=FEMAIL;
}
else testData[testlines][0]=MAIL;
fscanf(fp,"%lf",&height);
printf("%lf\n",height);
if(height>=1.6&&height<1.7)
{
testData[testlines][1]=0;
}
else if(height>=1.7&&height<1.8)
{
testData[testlines][1]=1;
}
else if(height>=1.8&&height<1.9)
{
testData[testlines][1]=2;
}
else if(height>=1.9&&height<2.0)
{
testData[testlines][1]=3;
}
else if(height>=2.0&&height<2.1)
{
testData[testlines][1]=4;
}
else if(height>=2.1&&height<=2.2)
{
testData[testlines][1]=5;
}
testlines++;
}
/*for(i=0;i<testlines;i++)
{
for(j=0;j<2;j++)
{
printf("%d\t",testData[i][j]);
}printf("\n");
}*/
Classify(0,root);
}

浙公网安备 33010602011771号