Big Data Analytics 简明教程

Big Data Analytics - Decision Trees

决策树是一种用于分类或回归等监督学习问题的算法。决策树或分类树是一种树,其中每个内部(非叶)节点都标记有输入特征。源自标记有某个特征的节点的弧线都标记有该特征的每个可能值。树的每个叶都标记有类别或类别概率分布。

可以通过基于属性值检验将源集拆分为子集来“学习”树。此过程在称为 recursive partitioning 的递归方式中对每个派生子集重复执行。当某个节点的子集具有目标变量的所有相同值,或当拆分不再为预测增加价值时,递归完成。这种由上至下归纳决策树的过程是贪心算法的一个示例,并且是学习决策树最常见的策略。

数据挖掘中使用的决策树主要有以下两种类型:

  1. Classification tree −当响应是名词变量时,例如电子邮件是否是垃圾邮件。

  2. Regression tree −当预测结果可以被认为是一个实数时(例如工人的工资)。

决策树是一种简单的方法,因此有一些问题。此问题之一是决策树产生的结果模型中的高方差。为了减轻此问题,开发了决策树集成方法。目前广泛使用两组集成方法−

  1. Bagging decision trees −这些树用于通过重复带替换地重新采样训练数据来构建多个决策树,并为共识预测投票树。此算法被称为随机森林。

  2. Boosting decision trees −梯度提升结合弱学习器;在此情况下,反复将决策树组合成一个单一的强学习器。它为数据拟合弱树,并通过迭代拟合弱学习器来校正前一个模型的误差。

# Install the party package
# install.packages('party')
library(party)
library(ggplot2)

head(diamonds)
# We will predict the cut of diamonds using the features available in the
diamonds dataset.
ct = ctree(cut ~ ., data = diamonds)

# plot(ct, main="Conditional Inference Tree")
# Example output
# Response:  cut
# Inputs:  carat, color, clarity, depth, table, price, x, y, z

# Number of observations:  53940
#
# 1) table <= 57; criterion = 1, statistic = 10131.878
#   2) depth <= 63; criterion = 1, statistic = 8377.279
#     3) table <= 56.4; criterion = 1, statistic = 226.423
#       4) z <= 2.64; criterion = 1, statistic = 70.393
#         5) clarity <= VS1; criterion = 0.989, statistic = 10.48
#           6) color <= E; criterion = 0.997, statistic = 12.829
#             7)*  weights = 82
#           6) color > E

#Table of prediction errors
table(predict(ct), diamonds$cut)
#            Fair  Good Very Good Premium Ideal
# Fair       1388   171        17       0    14
# Good        102  2912       499      26    27
# Very Good    54   998      3334     249   355
# Premium      44   711      5054   11915  1167
# Ideal        22   114      3178    1601 19988
# Estimated class probabilities
probs = predict(ct, newdata = diamonds, type = "prob")
probs = do.call(rbind, probs)
head(probs)