本文翻译自 Decision Trees,原载于 Hacker News。
决策树是如何工作的
我们在上一节中了解了决策树的高层运作方式:从上到下,它创建一系列顺序规则,将数据分割成边界清晰的区域用于分类。但面对如此多的可能选项,算法究竟如何确定在哪里分割数据?
在深入了解这一点之前,我们需要先理解熵(Entropy)的概念。
熵:衡量信息的纯度
熵用于衡量某个变量或事件所包含的信息量。我们会用它来识别包含大量相似(纯净)或不相似(不纯)元素的区域。
给定一组概率为 $(p_1, p_2, \dots, p_n)$ 的事件,总熵 $H$ 可以表示为加权概率的负数之和:
\[H = -\sum_{i=1}^{n} p_i \log_2(p_i)\]熵的三个重要性质
-
$H=0$ 当且仅当除了一个 $p_i$ 之外,其他所有概率都为零,而这个概率值为 1。这意味着当结果没有不确定性时,熵为零——样本完全可预测。
-
$H$ 在所有 $p_i$ 相等时达到最大值。这是最不确定、或者说最”不纯”的情况。
-
任何使概率趋向于平均化的变化都会增加 $H$。
熵可以用来量化带标签数据点集合的不纯度:包含多个类别的节点是不纯的,而只包含一个类别的节点是纯的。
💡 直观理解:纯样本的熵为零,而不纯样本的熵值较大。这正是熵为我们做的事情——衡量一组样本的纯度(或不纯度)。
信息增益:选择最佳分割点
理解了熵之后,我们可以描述训练决策树的逻辑了。顾名思义,信息增益(Information Gain)衡量我们获得的信息量。它使用熵来计算:用分割前的数据熵减去分割后各个分区的熵,然后选择使熵减少最大(即信息增益最大)的分割方式。
计算信息增益的核心算法叫做 ID3(Iterative Dichotomiser 3)。这是一个递归过程,从树的根节点开始,以贪婪的方式自顶向下遍历所有非叶节点,在每个深度计算熵的变化:
\[\Delta IG = H_{\text{parent}} - \frac{1}{N}\sum_{\text{children}} N_{\text{child}} \cdot H_{\text{child}}\]ID3 算法的具体步骤
-
计算每个特征的熵:遍历数据集的每个特征。
-
尝试所有可能的分割:使用不同的特征和切分值将数据集分成子集。对每个分割,使用上述公式计算信息增益 $\Delta IG$(分割前后熵的差值)。对于分割后所有子节点的总熵,使用加权平均,考虑 $N_{\text{child}}$(即 N 个样本中有多少落在该子节点上)。
-
选择最佳分割:找出产生最大信息增益的分割方式,在该特征和分割值上创建决策节点。
-
创建叶节点:当某个子集无法进一步分割时,创建叶节点。如果是分类任务,用该节点中最常见的类别标记;如果是回归任务,用平均值标记。
-
递归处理所有子集:如果分割后某个子节点的所有元素都属于同一类型,则停止递归。也可以设置额外的停止条件,比如要求每个叶节点的最小样本数,或限制树的最大深度。
一个具体的例子
让我们回顾第一个决策节点是如何选择的:Diameter ≤ 0.45。
这个条件是怎么选出来的?是通过最大化信息增益得出的。
对于 Diameter 特征的每个可能的分割值,都会产生不同的信息增益值。ID3 算法会选择信息增益最大的分割点——在这个例子中,信息增益在 Diameter = 0.45 处达到峰值 0.574。
基尼不纯度:另一种选择
除了熵,构建决策树还可以使用基尼不纯度(Gini Impurity)。这也是一种衡量信息的方法,可以看作香农熵的变体。
使用熵或基尼不纯度训练的决策树通常表现相当,只有在少数情况下结果会有显著差异:
- 对于不平衡数据集,熵可能更谨慎
- 基尼不纯度训练更快,因为它不使用对数运算
从另一个角度看决策树
让我们从数据点的角度来理解决策树。从上到下,随着数据被分配到不同的决策节点和叶节点,待分类的样本逐渐缩小。这种方式下,我们可以追踪训练数据点经过的完整路径。
注意,并非每个叶节点都是纯的——我们之前提到过(下一节会详细讨论),我们不希望决策树太深,因为这样的模型可能无法很好地泛化到新数据。
决策树的局限性
决策树有很多优点:
- ✅ 简单、易于解释
- ✅ 训练速度快
- ✅ 需要最少的数据预处理
- ✅ 能轻松处理异常值
但它有一个主要局限:相比其他预测器,决策树不稳定。
对数据扰动极其敏感
决策树对数据中的微小扰动非常敏感:训练样本的轻微变化可能导致决策树结构的剧烈变化。
你可以自己验证:仅仅对 5% 的训练样本添加小的随机高斯扰动,就会产生完全不同的决策树集合。
在原始形式下,决策树是不稳定的。
过拟合问题
如果不加控制,ID3 算法会不断工作以最小化熵。它会持续分割数据,直到所有叶节点都完全纯净——即只包含一个类别。这个过程可能产生非常深且复杂的决策树。
结合前述的高方差问题,这两者都是不可取的,因为它们导致预测器无法清晰地区分数据中的持久模式和随机模式——这就是过拟合(Overfitting)问题。
过拟合是有问题的,因为它意味着我们的模型在面对新数据时表现不佳。
解决方案
剪枝(Pruning)
可以通过以下方式防止决策树过度生长:
- 限制最大深度
- 限制可以创建的叶节点数量
- 设置每个叶节点的最小样本数
- 不允许样本数过少的叶节点存在
随机森林(Random Forest)
对于高方差问题呢?遗憾的是,这是训练单个决策树时的内在特性。
讽刺的是,缓解扰动引起的不稳定性的一种方法是在训练过程中引入额外的随机层。
在实践中,这可以通过创建在数据集的略微不同版本上训练的决策树集合来实现,它们的组合预测不会严重受到高方差的影响。这种方法为我们打开了迄今为止最成功的机器学习算法之一的大门:随机森林。
总结
本文涵盖了决策树算法的核心内容:
-
工作原理:决策树通过一系列条件规则,将特征空间反复分割成不同区域来分类数据
-
熵的概念:用于衡量给定数据样本纯度(或不纯度)的流行指标
-
信息增益与 ID3 算法:决策树如何使用熵来计算信息增益,确定选择哪些条件规则
-
局限性:决策树容易过拟合,对数据扰动敏感
-
优化方向:剪枝可以防止过度生长,而随机森林通过集成学习解决高方差问题
🎯 给开发者的建议:在实际项目中,scikit-learn 的
DecisionTreeClassifier和RandomForestClassifier是很好的起点。记得设置max_depth、min_samples_leaf等参数来防止过拟合。
注:原文包含精彩的交互式可视化动画,强烈建议访问原文体验完整效果。