机器学习算法之C4.5(C语言实现)

C4.5的学习

写的很简单,可以看下面的博客资源

C4.5算法相关资料

C4.5算法主要基于ID3算法改进:ID3算法利用信息增益构建决策树,而C4.5利用信息增益率来构建决策树,并且加入了剪枝(pruning)来减轻过拟合。

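以下公式用于计算所需信息。设数据集 $D$ 中第 $i$ 类样本所占比例为 $p_i$(共 $m$ 类),属性 $A$ 将 $D$ 划分为 $v$ 个子集 $D_1,\dots,D_v$,则:$Info(D) = -\sum_{i=1}^{m} p_i \log_2 p_i$;$Info_A(D) = \sum_{j=1}^{v} \frac{|D_j|}{|D|}\, Info(D_j)$;$Gain(A) = Info(D) - Info_A(D)$。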

ID3算法利用信息增益构建决策树,信息增益越大,则考虑以此属性进行分裂

C4.5算法考虑到信息增益偏向于取值较多的属性,因此改用信息增益率来选择分裂属性。

分裂信息:$SplitInfo_A(D) = -\sum_{j=1}^{v} \frac{|D_j|}{|D|} \log_2 \frac{|D_j|}{|D|}$;信息增益率:$GainRatio_A(D) = \frac{Gain(A)}{SplitInfo_A(D)}$。

以上是C4.5算法需要用到的公式。

下面使用UCI数据库中的数据,用C语言实现。(代码全部为原创,可能有点乱,后面可能会根据情况进一步修改;将训练得到的决策树用于分类预测,可以得到70%左右的正确率。)

根据结果小结:

1. 程序中通过限制树的深度来避免树过于复杂(也可以采用其他思路进行剪枝)。

2. 训练数据集采用了前1600个样本(总共1728个),更好的做法应该是随机选取训练数据。

只有两个代码文件,没有将函数分开,功能都在main.h实现,具体的用例在main.cpp中。

<main.h>

/****************************
 DATA: Car Evaluation Database (from UCI)
 Decision tree learning algorithm (C4.5, gain ratio), C implementation (Peng Jie)
 Copyright 2015/3/21 owner by pengjie
 All rights reserved
*******************************/
#ifndef MAIN_H
#define MAIN_H

#include "stdio.h"
#include "stdlib.h"
#include <map>
#include <iostream>
#include <string>
#include <string.h>
#include <vector>
#include <math.h>

#define maxClass 10 /* capacity of DesTreePoint::nextloca (max children per node) */
#define maxDeep 3   /* nodes deeper than this are always turned into leaves */

/* NOTE(review): this header defines non-inline functions and globals and pulls
   `std` into the global namespace; it is only safe to include from a single
   translation unit. Kept as-is so existing callers keep compiling. */
using namespace std;

/* Maps the textual value of one column (e.g. "vhigh") to its encoded character. */
typedef map<string, char> StringCharMap;

/* Index of the most recently allocated node in the caller's tree array.
   CreateDecisionTree resets it to 0 and leaves it at the last used index. */
int locp = 0;

/* One (attribute, value) test, i.e. one edge of the tree. */
struct attrPoint
{
    int attrName;   /* column index of the attribute; -1 on the root (no incoming edge) */
    char attrValue; /* encoded value, '0' + value index */
};

/* One node of the decision tree, stored in a flat array and linked by index. */
struct DesTreePoint
{
    attrPoint prepos;       /* test on the edge leading from the parent to this node */
    int curpos;             /* attribute this node splits on (-1 while undecided / for leaves) */
    int deep;               /* depth of the node, root = 0 */
    int nextPointNum;       /* number of children */
    //attrPoint nextattr[maxClass];//next attrname,0:none next point
    char label;             /* '#' for an internal/pending node, otherwise the encoded class */
    int loca;               /* index of this node in the tree array */
    int preloca;            /* index of the parent node, -1 for the root */
    int nextloca[maxClass]; /* indexes of the children */
};

/* Work stack of nodes that still have to be expanded by CreateDecisionTree. */
vector<DesTreePoint> treeQueue;

/* Dumps the bookkeeping fields of one tree node to stdout (debug helper). */
void print(DesTreePoint point)
{
    printf("prepos–name,value:%d,%c\n", point.prepos.attrName, point.prepos.attrValue);
    printf("curpos: %d\n", point.curpos);
    printf("deep: %d\n", point.deep);
    printf("nextPointNum: %d\n", point.nextPointNum);
    printf("loca: %d\n", point.loca);
    printf("preloca:%d\n", point.preloca);
}

/*
 * Reads up to `dataline` comma-separated records from `file` and stores the
 * encoded value of every column in data[row][column].
 *
 *   data      caller-allocated matrix with at least dataline x attrinum cells
 *   attriSome unused here; kept for interface compatibility
 *   mapAttr   array of `attrinum` maps, one per column, text value -> code
 *   file      path of the data file
 *   dataline  maximum number of records to read
 *   attrinum  number of columns per record (attributes + class)
 *
 * Returns 1 on success (also when the file holds fewer than `dataline`
 * records; the remaining rows are left untouched) and -1 when the file cannot
 * be opened, a record has too few fields, or a field is missing from its map.
 * Blank lines are skipped and trailing CR/LF characters are ignored.
 */
int ReadData(char **data, int attriSome[], StringCharMap *mapAttr, const char *file,
             const int dataline, const int attrinum)
{
    (void)attriSome;

    FILE *pFile = fopen(file, "rt");
    if (pFile == NULL)
    {
        printf("the data file is not existing: %s\n", file);
        return -1;
    }

    char buf[256];
    const char delim[] = ",";
    int status = 1;
    int row = 0;

    /* Drive the loop with fgets itself: testing feof() before reading would
       process the last buffer a second time once the end of file is hit. */
    while (status == 1 && row < dataline && fgets(buf, sizeof(buf), pFile) != NULL)
    {
        size_t len = strlen(buf);
        while (len > 0 && (buf[len - 1] == '\n' || buf[len - 1] == '\r'))
            buf[--len] = '\0';
        if (len == 0)
            continue; /* blank line */

        for (int column = 0; column < attrinum; ++column)
        {
            const char *token = strtok(column == 0 ? buf : NULL, delim);
            if (token == NULL)
            {
                printf("record %d has fewer than %d fields\n", row, attrinum);
                status = -1;
                break;
            }
            /* find() instead of operator[]: an unknown token must not be
               silently inserted and encoded as '\0'. */
            StringCharMap::const_iterator it = mapAttr[column].find(token);
            if (it == mapAttr[column].end())
            {
                printf("record %d, column %d: unknown value '%s'\n", row, column, token);
                status = -1;
                break;
            }
            data[row][column] = it->second;
        }
        if (status == 1)
            ++row;
    }

    fclose(pFile);
    return status;
}

/* True when `row` satisfies every (attribute, value) test of the first `deep`
   entries of `attrp`. With deep <= 0 every row matches and attrp is not read. */
static bool C45RowMatchesPath(const char *row, const attrPoint *attrp, int deep)
{
    for (int j = 0; j < deep; ++j)
    {
        if (row[attrp[j].attrName] != attrp[j].attrValue)
            return false;
    }
    return true;
}

/* One term -p*log2(p) of an entropy sum; probabilities below 1e-6 count as 0. */
static double C45EntropyTerm(double p)
{
    if (p < 0.000001)
        return 0.0;
    return -p * log(p) / log(2.0);
}

/* Most frequent class (index, first one wins on ties) among the rows matching
   the path; stores the number of matching rows in *matched when not NULL.
   Returns 0 when no row matches. */
static int C45MajorityLabel(char **data, int datasize, int attrinum, int deep,
                            const attrPoint *attrp, const int *attriSome, int *matched)
{
    const int classifynum = attriSome[attrinum - 1];
    std::vector<int> classCount(static_cast<size_t>(classifynum > 0 ? classifynum : 0), 0);
    int total = 0;

    for (int i = 0; i < datasize; ++i)
    {
        if (!C45RowMatchesPath(data[i], attrp, deep))
            continue;
        const int c = data[i][attrinum - 1] - '0';
        if (c < 0 || c >= classifynum)
            continue;
        ++classCount[c];
        ++total;
    }

    int majority = 0;
    for (int c = 1; c < classifynum; ++c)
    {
        if (classCount[c] > classCount[majority])
            majority = c;
    }
    if (matched != NULL)
        *matched = total;
    return majority;
}

/*
 * Gain ratio of attribute `attrname` on the subset of `data` selected by the
 * path attrp[0..deep-1]:  (Info(D) - Info_A(D)) / SplitInfo_A(D).
 *
 *   data           datasize x attrinum matrix of encoded values ('0' + index);
 *                  the last column is the class
 *   deep, attrp    number of / list of (attribute, value) tests a row must pass
 *   attriSome      number of distinct values per column
 *   classfiedLabel out: index of the most frequent class in the subset
 *                  (0 when the subset is empty)
 *
 * Returns 0.0 when the subset is empty, is (almost) pure (entropy ~ 0 or the
 * majority class covers more than 95%), or when the attribute takes a single
 * value on the subset (SplitInfo = 0). Rows whose attribute or class code is
 * outside the declared range are ignored.
 */
double GainRatio(char **data, const int datasize, const int attrinum, int deep,
                 attrPoint *attrp, int *attriSome, int attrname, int &classfiedLabel)
{
    const int labelnum = attriSome[attrname];        /* values of the attribute */
    const int classifynum = attriSome[attrinum - 1]; /* number of classes */

    classfiedLabel = 0;
    if (labelnum <= 0 || classifynum <= 0)
        return 0.0;

    /* Contingency table value x class, plus the two marginals. */
    std::vector<std::vector<int> > jointCount(
        static_cast<size_t>(labelnum),
        std::vector<int>(static_cast<size_t>(classifynum), 0));
    std::vector<int> valueCount(static_cast<size_t>(labelnum), 0);
    std::vector<int> classCount(static_cast<size_t>(classifynum), 0);
    int total = 0;

    for (int i = 0; i < datasize; ++i)
    {
        if (!C45RowMatchesPath(data[i], attrp, deep))
            continue;
        const int v = data[i][attrname] - '0';
        const int c = data[i][attrinum - 1] - '0';
        if (v < 0 || v >= labelnum || c < 0 || c >= classifynum)
            continue;
        ++jointCount[v][c];
        ++valueCount[v];
        ++classCount[c];
        ++total;
    }

    if (total == 0)
        return 0.0; /* nothing reaches this node */

    /* Info(D) and the majority class. */
    double infoD = 0.0;
    double maxpi = 0.0;
    int maxindex = 0;
    for (int c = 0; c < classifynum; ++c)
    {
        const double pi = double(classCount[c]) / total;
        if (pi > maxpi)
        {
            maxpi = pi;
            maxindex = c;
        }
        infoD += C45EntropyTerm(pi);
    }
    classfiedLabel = maxindex;
    if (fabs(infoD) < 0.0000001 || maxpi > 0.95)
        return 0.0; /* pure enough: not worth splitting */

    /* Info_A(D) and SplitInfo_A(D); empty partitions contribute nothing. */
    double infoDA = 0.0;
    double splitInfo = 0.0;
    for (int v = 0; v < labelnum; ++v)
    {
        if (valueCount[v] == 0)
            continue;
        double infoDj = 0.0;
        for (int c = 0; c < classifynum; ++c)
            infoDj += C45EntropyTerm(double(jointCount[v][c]) / valueCount[v]);
        const double weight = double(valueCount[v]) / total;
        infoDA += weight * infoDj;
        splitInfo += C45EntropyTerm(weight);
    }

    if (splitInfo < 0.000001)
        return 0.0; /* attribute is constant on this subset */
    return (infoD - infoDA) / splitInfo;
}

/* Makes tree[parentLoca] an internal node splitting on `attr`, allocates one
   child per attribute value at the next free indexes (advancing `locp`) and
   pushes the children on treeQueue. At most maxClass children are created. */
static void C45Split(DesTreePoint *tree, int parentLoca, int attr, const int *attriSome)
{
    int childNum = attriSome[attr];
    if (childNum > maxClass)
        childNum = maxClass; /* nextloca[] cannot hold more */
    if (childNum < 0)
        childNum = 0;

    tree[parentLoca].curpos = attr;
    tree[parentLoca].nextPointNum = childNum;

    for (int i = 0; i < childNum; ++i)
    {
        ++locp;
        tree[parentLoca].nextloca[i] = locp;

        DesTreePoint &child = tree[locp];
        child.prepos.attrName = attr;
        child.prepos.attrValue = char('0' + i);
        child.curpos = -1;
        child.deep = tree[parentLoca].deep + 1;
        child.nextPointNum = 0;
        child.label = '#';
        child.loca = locp;
        child.preloca = parentLoca;
        treeQueue.push_back(child);
    }
}

/*
 * Builds a decision tree for `data` into the caller-allocated array `tree`
 * (root at index 0), choosing at every node the unused attribute with the
 * highest gain ratio.
 *
 *   tree      output array; must be large enough for every node created
 *             (the function cannot check this)
 *   data      datasize x attrinum matrix of encoded values, class last
 *   attriSome number of distinct values per column
 *
 * A node becomes a leaf when it is deeper than maxDeep, when no unused
 * attribute is left, or when the best gain ratio is ~0. A leaf is labelled
 * with the majority class of its subset, or of its parent's subset when no
 * training row reaches it. The root is always split when an attribute exists.
 * Resets `locp` and `treeQueue`; on return `locp` is the index of the last
 * node in use.
 */
void CreateDecisionTree(DesTreePoint *tree, const int datasize, const int attrinum,
                        char **data, int *attriSome)
{
    const int featureNum = attrinum - 1; /* last column is the class */

    locp = 0;
    treeQueue.clear();

    tree[0].prepos.attrName = -1; /* marks the root for the upward path walk */
    tree[0].prepos.attrValue = '#';
    tree[0].curpos = -1;
    tree[0].deep = 0;
    tree[0].nextPointNum = 0;
    tree[0].label = '#';
    tree[0].loca = 0;
    tree[0].preloca = -1;

    if (featureNum < 1)
    {
        /* Nothing to split on: the root is a leaf with the overall majority class. */
        if (attrinum == 1)
            tree[0].label = char('0' + C45MajorityLabel(data, datasize, attrinum, 0, NULL,
                                                        attriSome, NULL));
        return;
    }

    /* Root: best attribute over the whole data set. */
    std::vector<double> ration(static_cast<size_t>(featureNum), 0.0);
    int classfiedLabel = -1;
    int rootAttr = 0;
    for (int i = 0; i < featureNum; ++i)
    {
        ration[i] = GainRatio(data, datasize, attrinum, 0, NULL, attriSome, i, classfiedLabel);
        if (ration[i] > ration[rootAttr])
            rootAttr = i;
    }
    C45Split(tree, 0, rootAttr, attriSome);

    std::vector<attrPoint> path;   /* tests from the current node up to the root */
    std::vector<char> lockAttr;    /* 1 = attribute already used on the path */

    while (!treeQueue.empty())
    {
        const DesTreePoint node = treeQueue.back();
        treeQueue.pop_back();

        /* Collect the tests on the way up to the root and lock their attributes. */
        path.clear();
        lockAttr.assign(static_cast<size_t>(featureNum), char(0));
        for (int at = node.loca; tree[at].prepos.attrName != -1; at = tree[at].preloca)
        {
            path.push_back(tree[at].prepos);
            lockAttr[tree[at].prepos.attrName] = 1;
        }
        const int depth = int(path.size());
        attrPoint *pathPtr = path.empty() ? NULL : &path[0];

        /* Best unused attribute for this node; the running maximum starts
           fresh for every node. */
        int bestAttr = -1;
        double bestRatio = 0.0;
        if (node.deep <= maxDeep)
        {
            for (int i = 0; i < featureNum; ++i)
            {
                if (lockAttr[i] == 1)
                    continue;
                const double r = GainRatio(data, datasize, attrinum, depth, pathPtr,
                                           attriSome, i, classfiedLabel);
                if (bestAttr < 0 || r > bestRatio)
                {
                    bestAttr = i;
                    bestRatio = r;
                }
            }
        }

        if (bestAttr < 0 || bestRatio < 0.000001)
        {
            /* Leaf. path[0] is this node's own edge, so path+1 is the parent's path. */
            int matched = 0;
            int majority = C45MajorityLabel(data, datasize, attrinum, depth, pathPtr,
                                            attriSome, &matched);
            if (matched == 0 && depth > 0)
                majority = C45MajorityLabel(data, datasize, attrinum, depth - 1, pathPtr + 1,
                                            attriSome, NULL);
            tree[node.loca].label = char('0' + majority);
        }
        else
        {
            C45Split(tree, node.loca, bestAttr, attriSome);
        }
    }
}

/*
 * Classifies one encoded sample with the tree and compares the result with
 * the sample's own class.
 *
 *   tree     tree produced by CreateDecisionTree (root at index 0)
 *   data     encoded sample with attrinum cells, class in the last one
 *
 * Returns true when the predicted class equals data[attrinum-1]; false when
 * it differs or when the sample holds a value no branch of the tree covers.
 */
bool predict(DesTreePoint *tree, char *data, const int attrinum)
{
    int at = 0;
    while (tree[at].label == '#')
    {
        int next = -1;
        for (int i = 0; i < tree[at].nextPointNum; ++i)
        {
            const int child = tree[at].nextloca[i];
            if (data[tree[child].prepos.attrName] == tree[child].prepos.attrValue)
            {
                next = child;
                break;
            }
        }
        if (next < 0)
            return false; /* no branch for this value: cannot classify */
        at = next;
    }
    return tree[at].label == data[attrinum - 1];
}

#endif

机器学习算法之C4.5(C语言实现)

相关文章:

你感兴趣的文章:

标签云: