Machine Learning Decision Trees (3): CART — Principles and Code Implementation


Decision Tree Series (3) — CART: Principles and Code Implementation

This article is original work by the author. Please credit the source when reposting: https://www.cnblogs.com/further-further-further/p/9482885.html

Shortcomings of ID3 and C4.5

An ID3 decision tree can have multiple branches per node, but it cannot handle continuous feature values.

At every step, ID3 selects the feature with the largest information gain and splits the data by every value of that feature: a feature with four possible values splits the data into four parts. Once a feature has been used for a split, it plays no further role in the rest of the algorithm, which is why some consider this way of consuming features too hasty.

C4.5 instead uses the information gain ratio as its splitting criterion. Like ID3, C4.5 is prone to overfitting.

To address overfitting, this article introduces another algorithm: CART.

CART (Classification and Regression Tree)

CART consists of feature selection, tree generation, and pruning, and can be used for both classification and regression.

Classification: e.g. sunny/cloudy/rainy weather, a user's gender, whether an e-mail is spam;

Regression: predicting real values, e.g. tomorrow's temperature or a user's age.

Generating a CART decision tree means recursively building a binary decision tree. Both splitting and pruning use an information-gain-maximisation criterion; here the gain is computed from the Gini index, although the ID3 entropy formula could be used instead.

Gini index

In a classification problem with K classes, let p_k be the probability that a sample point belongs to class k. The Gini index of the probability distribution is defined as

                  Gini(p) = \sum_{k=1}^{K} p_k (1 - p_k) = 1 - \sum_{k=1}^{K} p_k^2

For a given sample set D, where C_k denotes the subset of D belonging to class k, the Gini index is

                  Gini(D) = 1 - \sum_{k=1}^{K} \left( \frac{|C_k|}{|D|} \right)^2
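As a small worked calculation (using the same toy labels as the calculateDiffCount example later in the code, not the iris data): if D contains the 6 labels ['A', 'B', 'C', 'A', 'A', 'D'], the class counts are {A: 3, B: 1, C: 1, D: 1} and

                  Gini(D) = 1 - (3/6)^2 - (1/6)^2 - (1/6)^2 - (1/6)^2 = 1 - 12/36 ≈ 0.667

A pure node has Gini(D) = 0; the more evenly the classes are mixed, the larger the Gini index.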

 

The generated binary tree looks something like the following:

      [figure: example binary CART tree, image omitted]

Pruning algorithm

The CART pruning algorithm removes some subtrees from the bottom of the fully grown decision tree, making the tree smaller (and the model simpler) so that it predicts unseen data more accurately and avoids overfitting.

Post-pruning first grows a complete tree from the training set and then examines the non-leaf nodes bottom-up, using the information gain together with a given threshold to decide whether to replace a node's subtree with a leaf node; the criterion is sketched below.
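Concretely (this mirrors the prune() function in the code later in this article, and miniGain is the threshold parameter of that function): for a parent node whose two leaf children hold the sample sets D_L and D_R, compute

                  gain = Gini(D_L \cup D_R) - \frac{|D_L|}{|D_L|+|D_R|} Gini(D_L) - \frac{|D_R|}{|D_L|+|D_R|} Gini(D_R)

If gain < miniGain, the split is not worth keeping and the parent node is replaced by the leaf that holds the majority of the samples.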

   

Code implementation

I have commented almost every function in some detail, which I hope helps with understanding how the algorithm works.

Since attachments cannot be uploaded here, we have to do it the clumsy way: copy the raw data below into a local txt file, rename it to dataSet.csv, and place it in the same directory as the code files.

SepalLength,SepalWidth,PetalLength,PetalWidth,Name
5.1,3.5,1.4,0.2,setosa
4.9,3,1.4,0.2,setosa
4.7,3.2,1.3,0.2,setosa
4.6,3.1,1.5,0.2,setosa
5,3.6,1.4,0.2,setosa
5.4,3.9,1.7,0.4,setosa
4.6,3.4,1.4,0.3,setosa
5,3.4,1.5,0.2,setosa
4.4,2.9,1.4,0.2,setosa
4.9,3.1,1.5,0.1,setosa
5.4,3.7,1.5,0.2,setosa
4.8,3.4,1.6,0.2,setosa
4.8,3,1.4,0.1,setosa
4.3,3,1.1,0.1,setosa
5.8,4,1.2,0.2,setosa
5.7,4.4,1.5,0.4,setosa
5.4,3.9,1.3,0.4,setosa
5.1,3.5,1.4,0.3,setosa
5.7,3.8,1.7,0.3,setosa
5.1,3.8,1.5,0.3,setosa
5.4,3.4,1.7,0.2,setosa
5.1,3.7,1.5,0.4,setosa
4.6,3.6,1,0.2,setosa
5.1,3.3,1.7,0.5,setosa
4.8,3.4,1.9,0.2,setosa
5,3,1.6,0.2,setosa
5,3.4,1.6,0.4,setosa
5.2,3.5,1.5,0.2,setosa
5.2,3.4,1.4,0.2,setosa
4.7,3.2,1.6,0.2,setosa
4.8,3.1,1.6,0.2,setosa
5.4,3.4,1.5,0.4,setosa
5.2,4.1,1.5,0.1,setosa
5.5,4.2,1.4,0.2,setosa
4.9,3.1,1.5,0.1,setosa
5,3.2,1.2,0.2,setosa
5.5,3.5,1.3,0.2,setosa
4.9,3.1,1.5,0.1,setosa
4.4,3,1.3,0.2,setosa
5.1,3.4,1.5,0.2,setosa
5,3.5,1.3,0.3,setosa
4.5,2.3,1.3,0.3,setosa
4.4,3.2,1.3,0.2,setosa
5,3.5,1.6,0.6,setosa
5.1,3.8,1.9,0.4,setosa
4.8,3,1.4,0.3,setosa
5.1,3.8,1.6,0.2,setosa
4.6,3.2,1.4,0.2,setosa
5.3,3.7,1.5,0.2,setosa
5,3.3,1.4,0.2,setosa
7,3.2,4.7,1.4,versicolor
6.4,3.2,4.5,1.5,versicolor
6.9,3.1,4.9,1.5,versicolor
5.5,2.3,4,1.3,versicolor
6.5,2.8,4.6,1.5,versicolor
5.7,2.8,4.5,1.3,versicolor
6.3,3.3,4.7,1.6,versicolor
4.9,2.4,3.3,1,versicolor
6.6,2.9,4.6,1.3,versicolor
5.2,2.7,3.9,1.4,versicolor
5,2,3.5,1,versicolor
5.9,3,4.2,1.5,versicolor
6,2.2,4,1,versicolor
6.1,2.9,4.7,1.4,versicolor
5.6,2.9,3.6,1.3,versicolor
6.7,3.1,4.4,1.4,versicolor
5.6,3,4.5,1.5,versicolor
5.8,2.7,4.1,1,versicolor
6.2,2.2,4.5,1.5,versicolor
5.6,2.5,3.9,1.1,versicolor
5.9,3.2,4.8,1.8,versicolor
6.1,2.8,4,1.3,versicolor
6.3,2.5,4.9,1.5,versicolor
6.1,2.8,4.7,1.2,versicolor
6.4,2.9,4.3,1.3,versicolor
6.6,3,4.4,1.4,versicolor
6.8,2.8,4.8,1.4,versicolor
6.7,3,5,1.7,versicolor
6,2.9,4.5,1.5,versicolor
5.7,2.6,3.5,1,versicolor
5.5,2.4,3.8,1.1,versicolor
5.5,2.4,3.7,1,versicolor
5.8,2.7,3.9,1.2,versicolor
6,2.7,5.1,1.6,versicolor
5.4,3,4.5,1.5,versicolor
6,3.4,4.5,1.6,versicolor
6.7,3.1,4.7,1.5,versicolor
6.3,2.3,4.4,1.3,versicolor
5.6,3,4.1,1.3,versicolor
5.5,2.5,4,1.3,versicolor
5.5,2.6,4.4,1.2,versicolor
6.1,3,4.6,1.4,versicolor
5.8,2.6,4,1.2,versicolor
5,2.3,3.3,1,versicolor
5.6,2.7,4.2,1.3,versicolor
5.7,3,4.2,1.2,versicolor
5.7,2.9,4.2,1.3,versicolor
6.2,2.9,4.3,1.3,versicolor
5.1,2.5,3,1.1,versicolor
5.7,2.8,4.1,1.3,versicolor
6.3,3.3,6,2.5,virginica
5.8,2.7,5.1,1.9,virginica
7.1,3,5.9,2.1,virginica
6.3,2.9,5.6,1.8,virginica
6.5,3,5.8,2.2,virginica
7.6,3,6.6,2.1,virginica
4.9,2.5,4.5,1.7,virginica
7.3,2.9,6.3,1.8,virginica
6.7,2.5,5.8,1.8,virginica
7.2,3.6,6.1,2.5,virginica
6.5,3.2,5.1,2,virginica
6.4,2.7,5.3,1.9,virginica
6.8,3,5.5,2.1,virginica
5.7,2.5,5,2,virginica
5.8,2.8,5.1,2.4,virginica
6.4,3.2,5.3,2.3,virginica
6.5,3,5.5,1.8,virginica
7.7,3.8,6.7,2.2,virginica
7.7,2.6,6.9,2.3,virginica
6,2.2,5,1.5,virginica
6.9,3.2,5.7,2.3,virginica
5.6,2.8,4.9,2,virginica
7.7,2.8,6.7,2,virginica
6.3,2.7,4.9,1.8,virginica
6.7,3.3,5.7,2.1,virginica
7.2,3.2,6,1.8,virginica
6.2,2.8,4.8,1.8,virginica
6.1,3,4.9,1.8,virginica
6.4,2.8,5.6,2.1,virginica
7.2,3,5.8,1.6,virginica
7.4,2.8,6.1,1.9,virginica
7.9,3.8,6.4,2,virginica
6.4,2.8,5.6,2.2,virginica
6.3,2.8,5.1,1.5,virginica
6.1,2.6,5.6,1.4,virginica
7.7,3,6.1,2.3,virginica
6.3,3.4,5.6,2.4,virginica
6.4,3.1,5.5,1.8,virginica
6,3,4.8,1.8,virginica
6.9,3.1,5.4,2.1,virginica
6.7,3.1,5.6,2.4,virginica
6.9,3.1,5.1,2.3,virginica
5.8,2.7,5.1,1.9,virginica
6.8,3.2,5.9,2.3,virginica
6.7,3.3,5.7,2.5,virginica
6.7,3,5.2,2.3,virginica
6.3,2.5,5,1.9,virginica
6.5,3,5.2,2,virginica
6.2,3.4,5.4,2.3,virginica
5.9,3,5.1,1.8,virginica

Original data (dataSet.csv)
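As an alternative to copying the data by hand (a minimal sketch, not part of the original workflow; it assumes scikit-learn is installed and uses its bundled iris data, which matches the listing above), dataSet.csv can also be generated with a short script:

# generateDataSet.py -- writes dataSet.csv in the column order expected by loadCSV()
from sklearn.datasets import load_iris

iris = load_iris()
names = iris.target_names  # array(['setosa', 'versicolor', 'virginica'])
with open("dataSet.csv", "w") as f:
    f.write("SepalLength,SepalWidth,PetalLength,PetalWidth,Name\n")
    for row, label in zip(iris.data, iris.target):
        f.write("%s,%s\n" % (",".join(str(v) for v in row), names[label]))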

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 14 17:36:57 2018

@author: weixw
"""
import numpy as np

# Binary tree node: left subtree = condition is true, right subtree = condition is false
# leftBranch:  left child node
# rightBranch: right child node
# col:         column index with the largest information gain
# value:       split value of that column
# results:     classification result (None for internal nodes)
# summary:     node statistics at the best split
# data:        data set at the best split
class Tree:
    def __init__(self, leftBranch=None, rightBranch=None, col=-1, value=None, results=None, summary=None, data=None):
        self.leftBranch = leftBranch
        self.rightBranch = rightBranch
        self.col = col
        self.value = value
        self.results = results
        self.summary = summary
        self.data = data

    def __str__(self):
        print(u"column index: %d" % self.col)
        print(u"split value: %s" % self.value)
        print(u"node summary: %s" % self.summary)
        return ""


# Split the data set
def splitDataSet(dataSet, value, column):
    leftList = []
    rightList = []
    # numeric split value
    if isinstance(value, int) or isinstance(value, float):
        # iterate over every row
        for rowData in dataSet:
            # rows whose value in the given column is >= value go to leftList, the rest to rightList
            if rowData[column] >= value:
                leftList.append(rowData)
            else:
                rightList.append(rowData)
    # nominal split value
    else:
        # iterate over every row
        for rowData in dataSet:
            # rows whose value in the given column equals value go to leftList, the rest to rightList
            if rowData[column] == value:
                leftList.append(rowData)
            else:
                rightList.append(rowData)
    return leftList, rightList


# Count the number of samples per class label
'''
Helper for the gini computation. For an input such as ['A', 'B', 'C', 'A', 'A', 'D']
the output is {'A': 3, 'B': 1, 'C': 1, 'D': 1}, i.e. the per-class counts.
'''
def calculateDiffCount(dataSet):
    results = {}
    for data in dataSet:
        # data[-1] is the last column, i.e. the class label
        if data[-1] not in results:
            results.setdefault(data[-1], 1)
        else:
            results[data[-1]] += 1
    return results


# Gini index
def gini(dataSet):
    # total number of rows
    length = len(dataSet)
    # per-class counts
    results = calculateDiffCount(dataSet)
    imp = 0.0
    for i in results:
        imp += results[i] / length * results[i] / length
    return 1 - imp


# Build the decision tree
'''Algorithm steps'''
'''Starting from the root node, recursively apply the following to each node to build a binary decision tree:
1  Let D be the training data at the node. For each feature A and each possible value a, split D into
   D1 and D2 according to whether the test A >= a is "yes" or "no", and compute the information gain
   with the Gini index.
2  Over all features A and all candidate split points a, choose the feature and split point with the
   largest information gain as the best split, create two child nodes, and distribute the training
   data to them accordingly.
3  Recursively apply steps 1 and 2 to the two child nodes until the stopping condition is met.
4  The result is the CART decision tree.
'''
# evaluationFunc = gini: use the Gini index as the impurity measure
def buildDecisionTree(dataSet, evaluationFunc=gini):
    # Gini index of the current data set
    baseGain = evaluationFunc(dataSet)
    # length of each row (i.e. number of columns)
    columnLength = len(dataSet[0])
    # number of rows
    rowLength = len(dataSet)
    # initialise
    bestGain = 0.0    # largest information gain
    bestValue = None  # (column index, split value) at the largest gain
    bestSet = None    # (left subset, right subset) after splitting at the largest gain
    # iterate over every column except the label column (the last one)
    for col in range(columnLength - 1):
        # values of the column
        colSet = [example[col] for example in dataSet]
        # unique values of the column
        uniqueColSet = set(colSet)
        # iterate over the unique values
        for value in uniqueColSet:
            # split the data set
            leftDataSet, rightDataSet = splitDataSet(dataSet, value, col)
            # proportion of the left subset (in Python 3, "/" yields a float)
            prop = len(leftDataSet) / rowLength
            # information gain
            infoGain = baseGain - prop * evaluationFunc(leftDataSet) - (1 - prop) * evaluationFunc(rightDataSet)
            # keep the column index, value and subsets with the largest gain
            if infoGain > bestGain:
                bestGain = infoGain
                bestValue = (col, value)
                bestSet = (leftDataSet, rightDataSet)
    # node description
    # nodeDescription = {'impurity:%.3f'%baseGain,'sample:%d'%rowLength}
    nodeDescription = {'impurity': '%.3f' % baseGain, 'sample': '%d' % rowLength}
    # the labels are not all identical, keep splitting
    # the recursion needs a stopping condition
    if bestGain > 0:
        # recursion: build the left and right subtrees
        leftBranch = buildDecisionTree(bestSet[0], evaluationFunc)
        rightBranch = buildDecisionTree(bestSet[1], evaluationFunc)
        return Tree(leftBranch=leftBranch, rightBranch=rightBranch, col=bestValue[0],
                    value=bestValue[1], summary=nodeDescription, data=bestSet)
    else:
        # all labels are identical, stop splitting
        return Tree(results=calculateDiffCount(dataSet), summary=nodeDescription, data=dataSet)


def createTree(dataSet, evaluationFunc=gini):
    # recursively build the tree; stop when the gain is 0
    # Gini index of the current data set
    baseGain = evaluationFunc(dataSet)
    # length of each row (i.e. number of columns)
    columnLength = len(dataSet[0])
    # number of rows
    rowLength = len(dataSet)
    # initialise
    bestGain = 0.0    # largest information gain
    bestValue = None  # (column index, split value) at the largest gain
    bestSet = None    # (left subset, right subset) after splitting at the largest gain
    # iterate over every column except the label column (the last one)
    for col in range(columnLength - 1):
        # values of the column
        colSet = [example[col] for example in dataSet]
        # unique values of the column
        uniqueColSet = set(colSet)
        # iterate over the unique values
        for value in uniqueColSet:
            # split the data set
            leftDataSet, rightDataSet = splitDataSet(dataSet, value, col)
            # proportion of the left subset (in Python 3, "/" yields a float)
            prop = len(leftDataSet) / rowLength
            # information gain
            infoGain = baseGain - prop * evaluationFunc(leftDataSet) - (1 - prop) * evaluationFunc(rightDataSet)
            # keep the column index, value and subsets with the largest gain
            if infoGain > bestGain:
                bestGain = infoGain
                bestValue = (col, value)
                bestSet = (leftDataSet, rightDataSet)

    impurity = u'%.3f' % baseGain
    sample = '%d' % rowLength

    if bestGain > 0:
        bestFeatLabel = u'serial:%s\nimpurity:%s\nsample:%s' % (bestValue[0], impurity, sample)
        myTree = {bestFeatLabel: {}}
        myTree[bestFeatLabel][bestValue[1]] = createTree(bestSet[0], evaluationFunc)
        myTree[bestFeatLabel]['no'] = createTree(bestSet[1], evaluationFunc)
        return myTree
    else:  # the recursion needs a return value
        bestFeatValue = u'%s\nimpurity:%s\nsample:%s' % (str(calculateDiffCount(dataSet)), impurity, sample)
        return bestFeatValue


# Classification test:
'''Walk the binary tree with the given test sample and return the matching leaf node.'''
'''For example, with the test sample [5.9, 3, 4.2, 1.75] the tree built from the training data is
   traversed as follows: column 2 of the test sample is 4.2 => compare it with the value (3) of the
   root node (column 2); if it is >= 3 descend into the left subtree, otherwise into the right
   subtree; the leaf node reached is the result.'''
def classify(data, tree):
    # if this is a leaf node, return its information; otherwise keep walking
    if tree.results != None:
        return u"%s\n%s" % (tree.results, tree.summary)
    else:
        branch = None
        v = data[tree.col]
        # numeric value
        if isinstance(v, int) or isinstance(v, float):
            if v >= tree.value:
                branch = tree.leftBranch
            else:
                branch = tree.rightBranch
        else:  # nominal value
            if v == tree.value:
                branch = tree.leftBranch
            else:
                branch = tree.rightBranch
        return classify(data, branch)


def loadCSV(fileName):
    def convertTypes(s):
        s = s.strip()
        try:
            return float(s) if '.' in s else int(s)
        except ValueError:
            return s
    data = np.loadtxt(fileName, dtype='str', delimiter=',')
    data = data[1:, :]
    dataSet = [[convertTypes(item) for item in row] for row in data]
    return dataSet


# Majority vote:
# the most frequent value in the column wins
def majorityCnt(classList):
    import operator
    classCounts = {}
    for value in classList:
        if value not in classCounts.keys():
            classCounts[value] = 0
        classCounts[value] += 1
    sortedClassCount = sorted(classCounts.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


# Pruning algorithm (pre-order traversal: root => left subtree => right subtree)
'''Algorithm steps
1. Starting from the root node, recursively call the pruning routine until both children are leaf nodes.
2. Compute the information gain infoGain of the parent node (whose children are leaf nodes).
3. If infoGain < miniGain, replace the parent node with the leaf node that has more samples.
4. Repeat steps 1, 2, 3 until the whole tree has been visited.
'''
def prune(tree, miniGain, evaluationFunc=gini):
    print(u"current node:")
    print(str(tree))
    # if the left child is not a leaf, descend into the left subtree
    if tree.leftBranch.results == None:
        print(u"left child:")
        print(str(tree.leftBranch))
        prune(tree.leftBranch, miniGain, evaluationFunc)
    # if the right child is not a leaf, descend into the right subtree
    if tree.rightBranch.results == None:
        print(u"right child:")
        print(str(tree.rightBranch))
        prune(tree.rightBranch, miniGain, evaluationFunc)
    # both children are leaf nodes
    if tree.leftBranch.results != None and tree.rightBranch.results != None:
        # number of samples in the left leaf
        leftLen = len(tree.leftBranch.data)
        # number of samples in the right leaf
        rightLen = len(tree.rightBranch.data)
        # proportion of the left leaf
        leftProp = leftLen / (leftLen + rightLen)
        # information gain of this node (its children are leaves)
        infoGain = (evaluationFunc(tree.leftBranch.data + tree.rightBranch.data) -
                    leftProp * evaluationFunc(tree.leftBranch.data) -
                    (1 - leftProp) * evaluationFunc(tree.rightBranch.data))
        # if the gain is below the threshold, the leaves hardly differ from their parent, so prune
        if infoGain < miniGain:
            # merge the data of the two leaves
            dataSet = tree.leftBranch.data + tree.rightBranch.data
            # label column
            classLabels = [example[-1] for example in dataSet]
            # most frequent label
            keyLabel = majorityCnt(classLabels)
            # decide which of the two leaves that label belongs to
            if keyLabel in tree.leftBranch.results:
                # the left leaf replaces the parent node
                tree.data = tree.leftBranch.data
                tree.results = tree.leftBranch.results
                tree.summary = tree.leftBranch.summary
            else:
                # the right leaf replaces the parent node
                tree.data = tree.rightBranch.data
                tree.results = tree.rightBranch.results
                tree.summary = tree.rightBranch.summary
            tree.leftBranch = None
            tree.rightBranch = None

CART algorithm implementation (myCart.py)

'''
Created on Oct 14, 2010

@author: Peter Harrington
'''
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="circle", fc="0.7")
arrow_args = dict(arrowstyle="<-")

# count the leaf nodes of the tree
def getNumLeafs(myTree):
    numLeafs = 0
    # dict keys to list
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # test whether the child is a dict: if so it is an internal node, otherwise a leaf
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    # dict keys to list
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):  # the first key tells you what feature was split on
    numLeafs = getNumLeafs(myTree)  # this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]  # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # internal node (dict)
            plotTree(secondDict[key], cntrPt, str(key))  # recursion
        else:  # it's a leaf node, print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

# plot the decision tree, example 1
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
    # width and height spacing
    plotTree.totalW = float(getNumLeafs(inTree)) - 3
    plotTree.totalD = float(getTreeDepth(inTree)) - 2
    # plotTree.totalW = float(getNumLeafs(inTree))
    # plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.95, 1.0), '')
    plt.show()

# plot the decision tree, example 2
def createPlot1(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
    # width and height spacing
    plotTree.totalW = float(getNumLeafs(inTree)) - 4.5
    plotTree.totalD = float(getTreeDepth(inTree)) - 3
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (1.0, 1.0), '')
    plt.show()

# plot the root and leaf node shapes (root: rectangle, leaf: ellipse)
# def createPlot():
#     fig = plt.figure(1, facecolor='white')
#     fig.clf()
#     createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
#     plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#     plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#     plt.show()

def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]

# thisTree = retrieveTree(0)
# createPlot(thisTree)
# createPlot()
# myTree = retrieveTree(0)
# numLeafs = getNumLeafs(myTree)
# treeDepth = getTreeDepth(myTree)
# print(u"number of leaf nodes: %d" % numLeafs)
# print(u"tree depth: %d" % treeDepth)

Plotting the decision tree (treePlotter.py)

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 15 14:16:59 2018

@author: weixw
"""
import myCart as mc

if __name__ == '__main__':
    import treePlotter as tp
    dataSet = mc.loadCSV("dataSet.csv")
    myTree = mc.createTree(dataSet, evaluationFunc=mc.gini)
    print(u"myTree: %s" % myTree)
    # plot the decision tree
    print(u"plotting the decision tree:")
    tp.createPlot1(myTree)
    decisionTree = mc.buildDecisionTree(dataSet, evaluationFunc=mc.gini)
    testData = [5.9, 3, 4.2, 1.75]
    r = mc.classify(testData, decisionTree)
    print(u"classification result before pruning:")
    print(r)
    print()
    mc.prune(decisionTree, 0.4)
    r1 = mc.classify(testData, decisionTree)
    print(u"classification result after pruning:")
    print(r1)

Test code

Running results

The reason I wrote a separate createTree(dataSet, evaluationFunc=gini) function is that the plotting routine createPlot1(myTree) needs its input as a nested dict (JSON-like) structure.
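To make that structure concrete, here is the general shape of what createTree returns (an illustrative sketch only; the key uses the serial/impurity/sample format built in createTree, and the numbers are placeholders, not output from an actual run):

    {'serial:2\nimpurity:0.667\nsample:150': {
        3.0:  { ... },   # "yes" branch: subtree for rows whose column 2 value is >= 3.0
        'no': "{'setosa': 50}\nimpurity:0.000\nsample:50"   # "no" branch ended in a leaf: label counts plus statistics
    }}

Each internal node is a one-key dict; the "yes" branch is keyed by the split value itself and the "no" branch by the literal string 'no', which is what plotTree uses to label the edges.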

Turning the generated decision tree into a figure makes it much easier to grasp.

Of course, the information in the custom Tree objects can also be printed out; the print statements are already included in the code.

The printed output is shown below (only part of it, for screen-space reasons). Compare it against the plotted decision tree; the two cross-check each other and deepen the understanding.

[screenshot: printed node information, omitted]

The classification test result before pruning is as follows:

[screenshot: classification output before pruning, omitted]

The classification test result after pruning:

As you can see, {'versicolor': 47} has replaced the parent node serial:3 and become a new leaf node.

References

《统计学习方法》 (Statistical Learning Methods)

《机器学习实战》 (Machine Learning in Action)


