机器决策树算法后期预剪枝后剪枝
1.在进行决策树剪枝操作之前,我们先来看看为什么要进行决策树的剪枝操作?
那是因为决策树的过拟合的风险很大。因为理论上来说可以将数据完全分的开,如果树足够大,每个叶子节点就剩下了一个数据。那么,这就会造成模型在训练集上的拟合效果很好,但是泛化能力很差,对新样本的适应能力不足。所以,对决策树进行剪枝,可以降低过拟合的风险。
2.决策树的剪枝策略(预剪枝/后剪枝)
- 预剪枝
预剪枝在决策树生成过程中对每个节点先进行估计,如果划分能带来准确率上升则划分,否者不划分节点;后剪枝则是先使用训练集生成一棵决策树,再使用测试集对其节点进行评估,若将子树替换为叶子结点能带来准确率的提升则替换。
预剪枝优点:限制树的深度,叶子节点个数,叶子节点的样本数,信息增益量
- 后剪枝
后剪枝则是先从训练集生成一棵完整的决策树,然后自底向上地对非叶结点进行考察,若将该结点对应的子树替换为叶结点能带来决策树泛化性能提升,则将该子树替换为叶结点。
后剪枝和前剪枝的不同在于后剪枝是在生成决策树后再进行剪枝。顺序是由下到上
代码实现:
训练集:
测试集:
则后剪枝代码为:
-
# 后剪枝
-
def post_prunning(tree , test_data , test_label , names):
-
newTree = tree.copy() #copy是浅拷贝
-
names = np.asarray(names)
-
# 取决策节点的名称 即特征的名称
-
featName = list(tree.keys())[0]
-
# 取特征的列
-
featCol = np.argwhere(names == featName)[0][0]
-
names = np.delete(names, [featCol]) #删掉使用过的特征
-
newTree[featName] = tree[featName].copy() #取值
-
featValueDict = newTree[featName] #当前特征下面的取值情况
-
featPreLabel = featValueDict.pop("prun_label") #如果当前节点剪枝的话是什么标签,并删除_vpdl
-
-
# 分割测试数据 如果有数据 则进行测试或递归调用:
-
split_data = drop_exist_feature(test_data,featName) #删除该特征,按照该特征的取值重新划分数据
-
split_data = dict(split_data)
-
-
for featValue in featValueDict.keys(): #每个特征的值
-
if type(featValueDict[featValue]) == dict: #如果下一层还是字典,说明还是子树
-
-
split_data_feature = split_data[featValue] #特征某个取值的数据,如“脐部”特征值为“凹陷”的数据
-
split_data_lable = split_data[featValue].iloc[:, -1].values
-
# 递归到下一个节点
-
newTree[featName][featValue] = post_prunning(featValueDict[featValue],split_data_feature,split_data_lable,split_data_feature.columns)
-
-
# 根据准确率判断是否剪枝,注意这里的准确率是到达该节点数据预测正确的准确率,而不是整体数据集的准确率
-
# 因为在修改当前节点时,走到其他节点的数据的预测结果是不变的,所以只需要计算走到当前节点的数据预测对了没有即可
-
ratioPreDivision = equalNums(test_label, featPreLabel) / test_label.size #判断测试集的数据如果剪枝的准确率
-
-
#计算如果该节点不剪枝的准确率
-
ratioAfterDivision = predict_more(newTree, test_data, test_label)
-
-
if ratioAfterDivision < ratioPreDivision:
-
newTree = featPreLabel # 返回剪枝结果,其实也就是走到当前节点的数据最多的那一类
-
-
return newTree
-
-
if __name__ == '__main__':
-
#读取数据
-
train_data = pd.read_csv('./train_data.csv')
-
test_data = pd.read_csv('./test_data.csv')
-
test_data_label = test_data.iloc[:, -1].values
-
names = test_data.columns
-
-
dicision_Tree = {"脐部": {"prun_label": 1
-
, '凹陷': {'色泽':{"prun_label": 1, '青绿': 1, '乌黑': 1, '浅白': 0}}
-
, '稍凹': {'根蒂':{"prun_label": 1
-
, '稍蜷': {'色泽': {"prun_label": 1
-
, '青绿': 1
-
, '乌黑': {'纹理': {"prun_label": 1
-
, '稍糊': 1, '清晰': 0, '模糊': 1}}
-
, '浅白': 1}}
-
, '蜷缩': 0
-
, '硬挺': 1}}
-
, '平坦': 0}}
-
print('剪枝前的决策树:')
-
print(dicision_Tree)
-
print('剪枝前的测试集准确率: {}'.format(predict_more(dicision_Tree, test_data, test_data_label)))
-
-
print('-'*20 '剪枝' '-'*20)
-
new_tree = post_prunning(dicision_Tree,test_data , test_data_label , names)
-
print('剪枝后的决策树:')
-
print(new_tree)
-
print('剪枝后的测试集准确率: {}'.format(predict_more(new_tree, test_data, test_data_label)))
运行结果:
剪枝后决策树不仅更加轻量,而且对于测试集的预测准确率从0.428提升到了0.714
这篇好文章是转载于:学新通技术网
- 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
- 本站站名: 学新通技术网
- 本文地址: /boutique/detail/tanhggcffj
系列文章
更多
同类精品
更多
-
photoshop保存的图片太大微信发不了怎么办
PHP中文网 06-15 -
《学习通》视频自动暂停处理方法
HelloWorld317 07-05 -
Android 11 保存文件到外部存储,并分享文件
Luke 10-12 -
word里面弄一个表格后上面的标题会跑到下面怎么办
PHP中文网 06-20 -
photoshop扩展功能面板显示灰色怎么办
PHP中文网 06-14 -
微信公众号没有声音提示怎么办
PHP中文网 03-31 -
excel下划线不显示怎么办
PHP中文网 06-23 -
excel打印预览压线压字怎么办
PHP中文网 06-22 -
怎样阻止微信小程序自动打开
PHP中文网 06-13 -
TikTok加速器哪个好免费的TK加速器推荐
TK小达人 10-01