很抱歉又让你们等了两个月才更新文章,最近这段时间一直在忙着准备出国的材料,改文书、开证明基本上没课的时间都在旅顺校区和本部校区之间乱跑······不过看着自己的成绩单和几十张证书也是非常欣慰的,我感觉我的大学生活还是非常充实的,说说我的大学前三年是怎么过的吧,大一大二那会儿闲着没事的时候搞搞比赛,到了大三就去东软实训了一个学期我们实训从9月开始到12月就结束了因此要比在学校上课提前结束1个月左右吧,因此那年的冬天我没有选择立刻回家而是在大连这边找了一个小公司去实习了大概待了三个月不过这三个月确实收获颇丰因为在小公司遇到的人都非常好他们都很愿意分享技术而且还告诉我应该掌握怎样的前后端技术,Spring基本上是在公司才接触到,在学校那时如果单纯的说,学校课程基本没接触过框架,我自学的话也只是了解个皮毛基本上没有什么实际的开发来锻炼。之后就到了大三下半学期,这个学期课很多没什么自由的时间所以就基本上主要在忙着学校的课程和雅思考试了。到这我的大学前三年基本上就介绍完了,下面开始正题喔!

决策树(Decision Trees)

如果学过数据结构的话简单说其实它就是我们所接触过的树的实际应用,提起数据结构中的树我们一定记得树有叶子节点和根,同样决策树也是具备这些简单的元素。决策树主要优势就在于数据形式非常容易理解。决策树的一个重要任务是为了理解数据中所蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,这些机器根据数据集创建规则的过程,就是机器学习的过程。专家系统中经常使用决策树,而且决策树给出结果往在可以匹敌在当前领城其有几十年工作验的人类。

决策树的样子就是下图这样的:

 

决策树属性:

决策树属性
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征值的数据
缺点:可能会产生过度匹配的问题
适用数据类型:数值型和标称型

不过在了解决策树之前我们需要了解一些在数学上是如何使用信息论划分数据集,在构造决策树时,我们需要解决的第一个问题就是,当前数据集上哪个特征在划分数据类时起决定性作用。为了找到决定性的特征,划分出最好的结果,我们必须评估每个特征。完成测试之后,原始数据集就被划分为几个数据子集。这些数据子集会分布在第一个决策点的所有分支上。如果某个分支下的数据属于同一类型,则正确地划分数据分类,无需进一步对数据集进行分割。如果数据子集丙的数据不属于同一类型则需要重复划分数据子集的过程。划分数据子集的算法和划分原始数集的方法和同,直到所有具有相同类型的数据均在一个数据子集内。

因此我们需要进一步了解算法是如何划分数据集的:

决策树的一般流程
(1)收集数据:可以使用任何方法。
(2)准备数据:树构造算法只适用于标称型数据,因此数值型教据必须离散化。
(3)分析较据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
(4)训练算法构造树的数据结构。
(5)测试算法:使用经验树计算错误率。
(6)使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据
的内在含义。

信息增益(Information Gain)

划分数据的原则是:将无序的数据变得更加有序,组织杂乱无章的数据的一种方法就是使用信息论度量信息。在划分数据集之前之后信息发生的变化称为信息增益,知道如何计算信息增益我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。集合信息的度量方式称为香农熵(Entropy),这个名字来源于信息论之父 Claude Elwood Shannon

熵定义为信息的期望值,在明白这个概念之前我们需要知道信息的定义。如果待分类的事务可能存在多个分类之中,则符号xi的信息定义为:

其中pxi是选择该分类的概率。

为了计算熵,我们需要计算所有类别所有的可能值包含的信息期望值,通过下面这个公式得到:

其中n是分类的数目。

计算信息熵:

def calcShannonEnt(dataSet):
    '''
    计算信息增益(香农熵)
    香农熵越高表明混合的数据越多
    :param dataSet: 
    :return: 香农熵
    '''
    numEntries=len(dataSet)
    labelCounts={}
    #为所有可能分类创建字典
    for featVec in dataSet :
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    shannonEnt=0.0
    for key in labelCounts:
        prob=float(labelCounts[key])/numEntries
        #以2为底求对数
        shannonEnt-=prob*log(prob,2)
    return shannonEnt

numEntries用来存放数据集中实例的总数,然后创建一个字典,它的键值是最后一列的数值,如果当前键值不存在则扩展字典并将当前键值加入字典。每个键值都记录了当前类别出现的次数,最后使用所有类别标签的发生频率计算类别出现的概率,用这个概率可以计算香农熵。

我们可以简单测试一下该算法,这里有一张数据表其中包含5个海洋动物,特征包括:不浮出水面是否可以生存,以及是否含有脚蹼,我们可以通过这些值将动物分成两类:鱼类和非鱼类。

海洋生物数据表
不浮出水面是否可以生存是否含有脚蹼属于鱼类

测试代码如下:

def createDataSet():
    dataSet=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    labels=['no surfacing','flippers']

    return  dataSet,labels
mData,mLabels=createDataSet()
print calcShannonEnt(mData)

输出结果为:0.970950594455

注意:熵越高表明混合的数据也越多。

划分数据集

按照给定的特征划分数据集

def splitDataSet(dataSet,axis,value):
    '''
    按照特征划分数据集
    :param dataSet:带划分数据集 
    :param axis: 划分数据的特征
    :param values: 需要返回的特征值
    :return: 划分好的数据集
    '''
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec=featVec[:axis]
            #这里必须用extend,因为使用append添加后添加的元素也是一个列表
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

需要注意的是:Python语言不用考虑内存分配问题,Python语言在函数中传递的是列表的引用,在函数内部对列表对象的修改,将会影响该列表对象的整个生存周期。为了消除该影响,我们需要再函数开头声明一个新列表对象。还有Python自带的entend()和append()方法功能类似但是在处理多个列表时处理结果并不相同!

接下来我们需要遍历整个数据集,循环计算香农熵,找到最好的特征划分方式。

def chooseBestFeatureToSplit(dataSet):
    numFeatures=len(dataSet[0])-1
    baseEntropy=calcShannonEnt(dataSet)
    bestInfoGain=0.0
    bestFeature=-1
    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]
        uniqueVals=set(featList)
        newEntropy=0.0
        for value in uniqueVals:
            subDataSet=splitDataSet(dataSet,i,value)

            prob=len(subDataSet)/float(len(dataSet))

            newEntropy += prob*calcShannonEnt(subDataSet)
        infoGain=baseEntropy-newEntropy
        if(infoGain>bestInfoGain):
            bestInfoGain=infoGain
            bestFeature=i
    return bestFeature

需要注意的是在函数中调用的数据需要满足一定的要求:第一个要求是,数据必须是一种由列表元素构成的列表,而且所有的列表元素都要具有相同的数据长度;第二个要求是,数据的最后一列或者每个实例的最后一个元素时当前实例的类别标签。

测试一下我们写出的函数

mData,mLabels=createDataSet()
print chooseBestFeatureToSplit(mData)

输出将是0,表示第0个特征是最好的用于划分数据集特征,回到上面海洋动物特征表中,也就是说第一个特征值是1的放在一个组,特征值是0的放在另一组,则分类结果出现两个属于鱼类,一个属于非鱼类。

递归创建树

def createTree(dataSet,labels):
    classList=[example[-1] for example in dataSet]
    #类别相同则停止分类
    if classList.count(classList[0])==len(classList):
        return classList[0]
    #遍历完所有特征时返回出现次数最多的类别
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataSet)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    #得到列表包含的所有属性值
    del(labels[bestFeat])
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree

递归函数的第一个停止条件是所有的类标签完全相同,则直接返回该标签,第二个停止条件是用完了所有的特征,仍然不能将数据划分为仅包含唯一类别的分组,由于无法简单的返回一个唯一的标签,这里使用出现次数最多的类标签作为返回值。

使用Matplotlib绘制树形图

为了节省时间,我们定义如下函数来输出预选存储的树信息,避免每次测试代码都要从数据中创建树

def retrieveTree(i):
    listOfTrees=[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head':{0:'no', 1: 'yes'}},1:'no'}}}}]
    return listOfTrees[i]

1.获取叶节点的数目和树的深度

def getNumLeafs(myTree):
    numLeafs=0
    firstStr=myTree.keys()[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs+=getNumLeafs(secondDict[key])
        else:
            numLeafs+=1
    return numLeafs
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=myTree.keys()[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else:thisDepth=1
        if thisDepth>maxDepth:maxDepth=thisDepth
    return maxDepth

2.绘制节点间的文字信息

def plotMidText(cntrPt,parentPt,txtString):
    '''
    在父子节点间填充文本信息
    :param cntrPt: 
    :param parentPt: 
    :param txtString: 
    :return: 
    '''
    #先计算出父子节点间的中间位置
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
    #在计算出的位置绘制文字信息
    createPlot.ax1.text(xMid,yMid,txtString)

3.绘制树

def plotTree(myTree,parenPt,nodeTxt):
    '''
    绘制树
    :param myTree: 
    :param parenPt: 
    :param nodeTxt: 
    :return: 
    '''

    #下面两句用来计算树的宽和高
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=myTree.keys()[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parenPt,nodeTxt)
    plotNode(firstStr,cntrPt,parenPt,decisionNode)
    secondDict=myTree[firstStr]
    #减少y轴偏移量,并标注此处将要绘制叶子节点,因为是向下绘制树因此需要依次递减y坐标值
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
    #开始递归遍历整棵树
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD

4.显示树

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # totalW用来存贮树的宽度,通过树的宽度可以用于计算放置判断点的位置,主要的计算原则是将它放在所有叶子节点的中间而不仅仅是它叶子节点的中间
    plotTree.totalW = float(getNumLeafs(inTree))
    # totalD用来存贮树的深度,通过这两个参数可以计算树节点的摆放位置
    plotTree.totalD = float(getTreeDepth(inTree))
    # 记录已经绘制节点的位置,需要注意的是x轴和y轴的有效范围均为 0.0~1.0
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

其中plotTree.totalW用来存储树的宽度,plotTree.totalD用来存储树的深度,树的宽度用于计算放置判断点的位置,主要计算原则是放置在所有叶子节点的中间,plotTree.xOff , plotTree.yOff 表示已绘制点的坐标,以及放置下一个节点的位置,需要知道的是绘制图形的x轴和轴的范围均为0.0~1.0 。

测试一下上述代码:

mTree=retrieveTree(0)
createPlot(mTree)

绘制结果为:

现在就可以进行一个实际例子来应用决策树了,眼科医生是如何判断患者需要佩戴的隐形眼镜的镜片类型,隐形眼镜类型包括硬材质、软材质以及不适合佩戴隐形眼镜,数据在我上传的代码中名为lenses.txt的文件

我们开始测试数据:

fr=open('lenses.txt')
lenses=[inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels=['age','prescript','astigmatic','tearRate']
lensesTree=createTree(lenses,lensesLabels)
treePlotter.createPlot(lensesTree)

输出的图像为:

可以看出沿着不同的分支我们可以得到不同患者需要佩戴的隐形眼镜类型,而且医生也只需要问四个问题即可确定患者需要佩戴的隐形眼镜类型。

点我开启源码传送门!

如有问题请与我联系喔~请发送邮件到如下邮箱,我看到后会及时回复。

bubblyyi@outlook.com

欢迎转载,转载时请标明出。

查看更多文章请访问我的个人博客 www.bubblyyi.com

请喝咖啡☕️

打赏

4 Comments

JamesMark进行回复 取消回复

您的电子邮箱地址不会被公开。 必填项已用*标注