文章

机器学习之决策树笔记

机器学习之决策树笔记

环境准备

平台:windows10 64 位
IDE:Pycharm
Python 版本:Python 3.5
github 代码源代码 my_DecisionTree.py


决策树

决策树(decision tree)是一种常见的机器学习方法,其实在生活中我们已经用到了决策树相关的知识,比如说,女生相亲时的想法就是决策树的一种体现:

决策树

那么对于一般女生来说,首选的就是看对方年龄,根据年龄是否超过 30 来决定见不见南方,如果超过,不见,如果没有超过就继续判断,依次类推,这就是一个决策树。那么对于上面的决策树来说,要解决的问题就是对当前男生的数据构建一颗决策树,用来对未来男生进行分类,即当要进行下一次相亲时,根据决策树来判断是否见面。一般的,一颗决策树包含一个根节点、若干个内部节点和若干个叶节点,叶节点对应决策及结果,其他节点对应于一个属性的测试。可以很轻松的理解到,上图中年龄为根节点,长相、收入、公务员为内部节点,见或不见为叶节点。


基本算法

决策树基本算法图


数据划分

weka数据集 要想建立一个决策树,首先需要建立一个根节点,对于以上的数据集来说就是先根据那个类别来划分,即‘outlook、temperature、humidity、windy’中的哪个类别作为根节点,这就需要一个量来作为度量,那就是信息增益,以信息增益来作为划分依据的成为ID3(Iterative Dichotomiser 迭代二分器)算法,还有以增益率(gain ratio)来划分的称为C4.5 算法,以基尼系数(Gini index)划分的成为CART 决策树


信息增益

信息熵

信息熵(information entropy)是用来度量信息源的不确定度。它的公式如下:

\[Ent(D)=-\sum_{k=1}^{|y|}p_k\log_2p_k \tag{1}\]

其中$p_k$为数据集$D$中的$k$类样本所占的比例,$Ent(D)$越小,则$D$的纯度越高。 对于 weka 数据集来说,该数据集共有 14 个样本,用来预测某一天是否合适外出游玩,那么显然这里就得$k=1,2$,即外出或不外出两种情况,外出所占比例为$\frac{9}{14}$,而不外出所占比例为$\frac{5}{14}$,根据以上公式根节点的信息增益可以计算出来为: \(Ent(D)=-(\frac{9}{14}\log_2\frac{9}{14}+\frac{5}{14}\log_2\frac{5}{14}=0.94)\) 即为该数据集的信息熵

条件熵

条件熵的公式如下: \(Ent(D)=-\sum_{v=1}^V\frac{|D^v|}{|D|}Ent(D^v)\tag{2}\) 以‘outlook’为例来计算在 outlook 条件下的信息熵,那么就要计算当前属性集合中的每个属性的信息增益,即‘sunny’‘overcast’‘rainy’这三个属性的每一个属性的信息增益,先来计算 sunny 属性的信息增益,将天气为sunny的保留,得到如下所示数据集:

outlook temperature humidity windy play
sunny hot high FALSE no
sunny hot high TRUE no
sunny mild high FALSE no
sunny cool normal FALSE yes
sunny mild normal TRUE yes

这里在sunny的情况下,外出所占比例为$\frac{2}{5}$,不外出比例为$\frac{3}{5}$,故可以计算出此时的信息熵$Ent(D^1)$,同样的情况下可以计算出overcastrainy的信息熵$Ent(D^2)$,$Ent(D^3)$,然后sunnyoutlook比例为$\frac{5}{14}$,overcastraniy占比为$\frac{4}{14}$,$\frac{5}{14}$,那么此时的信息增益即为:

$Gain(D,outlook)=Ent(D)-(\frac{5}{14}Ent(D^1)+\frac{4}{14}Ent(D^2)+\frac{5}{14}Ent(D^3)$

这就是outlook属性的信息增益,同样可以计算出其他属性的信息增益,将其比较大小,找出最大值作为第一次分类的特征,即作为根节点。然后依次迭代循环,计算出整个决策树。


Python 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
dataSet = [['sunny', 'hot', 'high', 'FALSE', 'no'],
           ['sunny', 'hot', 'high', 'TRUE', 'no'],
           ['overcast', 'hot', 'high', 'FALSE', 'yes'],
           ['rainy', 'mild', 'high', 'FALSE', 'yes'],
           ['rainy', 'cool', 'normal', 'FALSE', 'yes'],
           ['rainy', 'cool', 'normal', 'TRUE', 'no'],
           ['overcast', 'cool', 'normal', 'TRUE', 'yes'],
           ['sunny', 'mild', 'high', 'FALSE', 'no'],
           ['sunny', 'cool', 'normal', 'FALSE', 'yes'],
           ['rainy', 'mild', 'normal', 'FALSE', 'yes'],
           ['sunny', 'mild', 'normal', 'TRUE', 'yes'],
           ['overcast', 'mild', 'high', 'TRUE', 'yes'],
           ['overcast', 'hot', 'normal', 'FALSE', 'yes'],
           ['rainy', 'mild', 'high', 'TRUE', 'no']]
labels = ['outlook','temperature','humidity','windy','play']

首先创建数据集,以列表的方式存储。

1
2
3
4
5
6
7
8
9
10
11
12
13
def calcentropy(data):
    numentropy = len(data)
    labelCounts = {}
    for featVec in data:  # 数据集的每一行
        currentlabel = featVec[-1]  # 每一行的最后一个类别
        if currentlabel not in labelCounts.keys():  # 如果当前标签不在字典的关键字中
            labelCounts[currentlabel] = 0  # 让当前标签为0,实际上是增加字典的关键字
        labelCounts[currentlabel] += 1  # 如果在字典里,就增加1,实际上是统计每个标签出现的次数
    shannonEnt = 0  # 香农熵,即最后的返回结果
    for key in labelCounts:     # 遍历labelCount中的每一个关键词
        prob = labelCounts[key] / numentropy    # 用关键词的个数除以总个数得到概率
        shannonEnt -= prob * np.log2(prob)      # 求信息熵,即香农熵
return shannonEnt

然后创建一个函数用来计算信息熵,这个函数有个特点就是用 for 循环和字典来记录一列数据中每个属性出现的次数,然后通过出现的次数除以总次数来计算概率,最后求得信息熵。

1
2
3
4
5
6
7
8
def splitData(data, axis, value):       # 划分数据集
    retData = []
    for featVec in data:    # 用featVec表示每一个样本
        if featVec[axis] == value:      # 如果value的值等于里面样本的特征
            reduceFeat = featVec[:axis]     # 用reduceFeat补齐这个特征之前的所有特征
            reduceFeat.extend(featVec[axis + 1:])       # 补齐这个特征之后的所有特征
            retData.append(reduceFeat)      # 用retData来表示去掉这个特征的最终特征
return retData

该函数为划分数据集函数,这里面用 reduceFeat 来存储去掉某个特征之后的数据集,然后返回该数据集用以迭代循环。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def chooseBestFeatSplit(data):
    numFeats = len(data[0]) - 1     # 特征的数目,减去最后一个特征
    bestGain = 0
    for i in range(numFeats):       # i:0-3,在本例中只有4个特征
        featlist = [example[i] for example in data]     # 列表解析取出data中的一列数据
        uniqueVals = set(featlist)          # 转变为集合
        newEntropy = 0
        for value in uniqueVals:            # 遍历每个特征内所有集合的元素,对于i=0是,uniqueVals={summy,rainy,overcast}
            subData = splitData(data, i, value)      # 对第一个特征中rainy划分数据集
            prob = len(subData) / float(len(data))
            newEntropy += prob * calcentropy(subData)   # 求出条件熵
        infoGain = baseGain - newEntropy                # 计算信息增益
        if infoGain > bestGain:                         # 取出最大的信息增益
            bestGain = infoGain
            bestfeat = i
return bestfeat                                     # 返回当信息增益最大是的特征类别

该函数是选择最佳的特征来分类,是整个步骤中最重要的一步,这里它用了一个列表解析取出数据集中某一列数据,然后转变为集合,通过比较每个类别的信息增益来返回最佳的特征

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def creatTree(data, labels):
    classList = [example[-1] for example in data]       # 遍历数据集最后一列
    if classList.count(classList[0]) == len(classList):     # 类别中的概率是否确定,即概率为1时,提前到达叶节点
        return classList[0]                             # 返回该类别
    if len(data[0]) == 1:       # 是否只剩下一个类别,到达最终的叶节点
        return majorityCnt(classList)
    bestFeature = chooseBestFeatSplit(data)         # 选择最好的特征
    bestFeatureLabel = labels[bestFeature]          # 最好的特征表现
    myTree = {bestFeatureLabel:{}}                  # 创建树
    del(labels[bestFeature])                        # 删除已经计算过的特征
    featValues = [example[bestFeature] for example in data]     # 遍历根据最佳特征分割完成后的类别中的属性
    uniqueVals = set(featValues)
    for value in uniqueVals:
        sublabels = labels[:]
        myTree[bestFeatureLabel][value] = creatTree(splitData(data,bestFeature,value),sublabels)
return myTree

该函数即为上图中决策树学习的基本算法,可以看到,它以 2 个条件做为迭代终止的条件,第一个是样本在类别中取值相同,这就说明已经到达了叶子节点,做出了最终决策,该阶段任务完成。第二个是只剩下一个类别时,迭代终止,到达最终的叶子节点。 以上程序为树的核心程序,剩下的用 matplot 绘图工具将生成的树绘出即可,程序如下

1
2
3
4
if __name__ == '__main__':
    baseGain = calcentropy(dataSet)
    myTree = creatTree(dataSet,labels)
    print(myTree)

最后运行结果出的决策树如下: 最终决策树图


参考书目

本文由作者按照 CC BY 4.0 进行授权