Python-如何从scikit-learn决策树中提取决策规则?

我可以从决策树中经过训练的树中提取出基本的决策规则(或“决策路径”)作为文本列表吗?

就像是:

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

谢谢你的帮助。

回答:

我相信这个答案比这里的其他答案更正确:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):

tree_ = tree.tree_

feature_name = [

feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"

for i in tree_.feature

]

print "def tree({}):".format(", ".join(feature_names))

def recurse(node, depth):

indent = " " * depth

if tree_.feature[node] != _tree.TREE_UNDEFINED:

name = feature_name[node]

threshold = tree_.threshold[node]

print "{}if {} <= {}:".format(indent, name, threshold)

recurse(tree_.children_left[node], depth + 1)

print "{}else: # if {} > {}".format(indent, name, threshold)

recurse(tree_.children_right[node], depth + 1)

else:

print "{}return {}".format(indent, tree_.value[node])

recurse(0, 1)

这会打印出有效的Python函数。这是尝试返回其输入的树的示例输出,该数字介于0到10之间。

def tree(f0):

if f0 <= 6.0:

if f0 <= 1.5:

return [[ 0.]]

else: # if f0 > 1.5

if f0 <= 4.5:

if f0 <= 3.5:

return [[ 3.]]

else: # if f0 > 3.5

return [[ 4.]]

else: # if f0 > 4.5

return [[ 5.]]

else: # if f0 > 6.0

if f0 <= 8.5:

if f0 <= 7.5:

return [[ 7.]]

else: # if f0 > 7.5

return [[ 8.]]

else: # if f0 > 8.5

return [[ 9.]]

这是我在其他答案中看到的一些绊脚石:

  1. 使用tree_.threshold == -2来决定一个节点是否为叶是不是一个好主意。如果它是阈值为-2的真实决策节点,该怎么办?相反,你应该查看tree.feature或tree.children_*。
  2. 该行在features = [feature_names[i] for i in tree_.feature]我的sklearn版本中崩溃,因为某些值tree.tree_.feature是-2(特别是对于叶节点)。
  3. 递归函数中不需要有多个if语句,只需一个就可以了。

以上是 Python-如何从scikit-learn决策树中提取决策规则? 的全部内容, 来源链接: utcz.com/qa/435160.html

回到顶部