A Complete Guide to Decision Tree Training: From Theory to Practice
1. Fundamentals of Decision Tree Training
At its core, training a decision tree means recursively choosing the best feature to split the data on, so that each resulting subset is as "pure" as possible, i.e. contains samples of a single class.
1.1 Mathematical Foundations of the Training Process
Decision tree training relies on a few key impurity measures:
Entropy:
$$Entropy(S) = -\sum_{i=1}^{c} p_i \log_2 p_i$$
where $p_i$ is the proportion of class-$i$ samples in set $S$.
Gini Impurity:
$$Gini(S) = 1 - \sum_{i=1}^{c} p_i^2$$
Information Gain:
$$IG(S, A) = Entropy(S) - \sum_{v \in Values(A)} \frac{|S_v|}{|S|}\, Entropy(S_v)$$
where $A$ is a feature and $S_v$ is the subset of samples for which feature $A$ takes value $v$.
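As a concrete check of these formulas, take the classic 14-sample play-tennis dataset that this article's features mirror (9 positive, 5 negative samples):

$$Entropy(S) = -\tfrac{9}{14}\log_2\tfrac{9}{14} - \tfrac{5}{14}\log_2\tfrac{5}{14} \approx 0.940$$

$$Gini(S) = 1 - \left(\tfrac{9}{14}\right)^2 - \left(\tfrac{5}{14}\right)^2 \approx 0.459$$

Splitting on Outlook yields the subsets Sunny (2+/3-, entropy about 0.971), Overcast (4+/0-, entropy 0), and Rain (3+/2-, entropy about 0.971), so

$$IG(S, Outlook) = 0.940 - \tfrac{5}{14}(0.971) - \tfrac{4}{14}(0) - \tfrac{5}{14}(0.971) \approx 0.247$$

Outlook has the highest gain of the four features, which is why it ends up at the root of the trained tree.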
2. The Complete Training Workflow
2.1 Data Preparation
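Before the preprocessing code, note that the snippets in this article reference several supporting types (DataPoint, the feature enums, the TreeNode hierarchy, and SplitResult) whose definitions are never shown. The following is a minimal sketch of what they could look like, inferred from how the later code uses them; the exact fields and structure are assumptions, not the author's definitions.

import java.util.HashMap;
import java.util.List;
import java.util.Map;

enum Outlook { SUNNY, OVERCAST, RAIN }
enum Temperature { HOT, MILD, COOL }
enum Humidity { HIGH, NORMAL }
enum Wind { WEAK, STRONG }
enum PlayTennis { YES, NO }

// One labeled sample; null fields represent missing values
class DataPoint {
    Outlook outlook;
    Temperature temperature;
    Humidity humidity;
    Wind wind;
    PlayTennis playTennis;
}

abstract class TreeNode {
    abstract void print(String indent);
}

// Internal node: splits on a feature; threshold is null for categorical splits
class DecisionNode extends TreeNode {
    String feature;
    Double threshold;
    Map<Object, TreeNode> children = new HashMap<>();
    TreeNode defaultChild; // fallback used by reduced-error pruning
    DecisionNode(String feature, Double threshold) {
        this.feature = feature;
        this.threshold = threshold;
    }
    void addChild(Object value, TreeNode child) { children.put(value, child); }
    void print(String indent) {
        System.out.println(indent + "[" + feature + "]");
        children.forEach((value, child) -> {
            System.out.println(indent + "  = " + value + ":");
            child.print(indent + "    ");
        });
    }
}

// Leaf node: predicts a single class
class LeafNode extends TreeNode {
    PlayTennis label;
    LeafNode(PlayTennis label) { this.label = label; }
    void print(String indent) { System.out.println(indent + "-> " + label); }
}

// Result of evaluating one candidate split
class SplitResult {
    String feature;
    Double threshold;
    double gain;
    Map<Object, List<DataPoint>> splitData;
    SplitResult(String feature, Double threshold, double gain,
                Map<Object, List<DataPoint>> splitData) {
        this.feature = feature;
        this.threshold = threshold;
        this.gain = gain;
        this.splitData = splitData;
    }
}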
public class DataPreprocessor {
// Fill in missing feature values using caller-supplied defaults
public static List<DataPoint> handleMissingValues(List<DataPoint> data,
Map<String, Object> defaultValues) {
List<DataPoint> cleanedData = new ArrayList<>();
for (DataPoint point : data) {
DataPoint cleanedPoint = new DataPoint();
// Handle missing values feature by feature
if (point.outlook == null) {
cleanedPoint.outlook = (Outlook) defaultValues.getOrDefault("outlook", Outlook.SUNNY);
} else {
cleanedPoint.outlook = point.outlook;
}
if (point.temperature == null) {
cleanedPoint.temperature = (Temperature) defaultValues.getOrDefault("temperature", Temperature.MILD);
} else {
cleanedPoint.temperature = point.temperature;
}
// Handle the remaining features (humidity, wind) the same way...
cleanedPoint.playTennis = point.playTennis;
cleanedData.add(cleanedPoint);
}
return cleanedData;
}
// Feature encoding: map categorical values to integers
public static Map<String, Map<Object, Integer>> fitLabelEncoders(List<DataPoint> data) {
Map<String, Map<Object, Integer>> encoders = new HashMap<>();
// Build one encoder per categorical feature
encoders.put("outlook", createEncoder(data.stream().map(dp -> dp.outlook).collect(Collectors.toList())));
encoders.put("temperature", createEncoder(data.stream().map(dp -> dp.temperature).collect(Collectors.toList())));
encoders.put("humidity", createEncoder(data.stream().map(dp -> dp.humidity).collect(Collectors.toList())));
encoders.put("wind", createEncoder(data.stream().map(dp -> dp.wind).collect(Collectors.toList())));
return encoders;
}
private static Map<Object, Integer> createEncoder(List<Object> values) {
Map<Object, Integer> encoder = new HashMap<>();
int code = 0;
for (Object value : values) {
if (value != null && !encoder.containsKey(value)) {
encoder.put(value, code++);
}
}
return encoder;
}
}
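One thing to note: fitLabelEncoders only fits the encoders; nothing in the pipeline later applies them, because the trainer works directly on the enum values. If you wanted integer feature vectors (say, to hand the data to another library), applying an encoder is a lookup. The fragment below is a hypothetical illustration using the classes above, where cleanedData and point stand for whatever data you have in scope:

// Hypothetical usage: encode one sample's outlook as its integer code
Map<String, Map<Object, Integer>> encoders = DataPreprocessor.fitLabelEncoders(cleanedData);
int outlookCode = encoders.get("outlook").get(point.outlook);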
2.2 The Core Training Algorithm
public class DecisionTreeTrainer {
private int maxDepth;
private int minSamplesSplit;
private int minSamplesLeaf;
private double minImpurityDecrease;
private String criterion;
public DecisionTreeTrainer(int maxDepth, int minSamplesSplit,
int minSamplesLeaf, double minImpurityDecrease,
String criterion) {
this.maxDepth = maxDepth;
this.minSamplesSplit = minSamplesSplit;
this.minSamplesLeaf = minSamplesLeaf;
this.minImpurityDecrease = minImpurityDecrease;
this.criterion = criterion;
}
public TreeNode buildTree(List<DataPoint> data, List<String> features, int depth) {
// Check the stopping conditions
if (shouldStop(data, features, depth)) {
return createLeafNode(data);
}
// Pick the best feature to split on
SplitResult bestSplit = findBestSplit(data, features);
if (bestSplit == null || bestSplit.gain < minImpurityDecrease) {
return createLeafNode(data);
}
// Create a decision node
DecisionNode node = new DecisionNode(bestSplit.feature, bestSplit.threshold);
// Recursively build the subtrees
List<String> remainingFeatures = new ArrayList<>(features);
remainingFeatures.remove(bestSplit.feature);
for (Map.Entry<Object, List<DataPoint>> entry : bestSplit.splitData.entrySet()) {
TreeNode childNode = buildTree(entry.getValue(), remainingFeatures, depth + 1);
node.addChild(entry.getKey(), childNode);
}
return node;
}
private boolean shouldStop(List<DataPoint> data, List<String> features, int depth) {
// Maximum depth reached
if (depth >= maxDepth) {
return true;
}
// Too few samples to split
if (data.size() < minSamplesSplit) {
return true;
}
// All samples belong to the same class
if (isPure(data)) {
return true;
}
// No features left to split on
if (features.isEmpty()) {
return true;
}
return false;
}
private boolean isPure(List<DataPoint> data) {
if (data.isEmpty()) return true;
PlayTennis firstClass = data.get(0).playTennis;
return data.stream().allMatch(dp -> dp.playTennis == firstClass);
}
private SplitResult findBestSplit(List<DataPoint> data, List<String> features) {
SplitResult bestSplit = null;
double bestGain = -Double.MAX_VALUE;
for (String feature : features) {
SplitResult split = calculateSplitForFeature(data, feature);
if (split != null && split.gain > bestGain) {
bestGain = split.gain;
bestSplit = split;
}
}
return bestSplit;
}
private SplitResult calculateSplitForFeature(List<DataPoint> data, String feature) {
if (isCategoricalFeature(feature)) {
return splitCategoricalFeature(data, feature);
} else {
return splitContinuousFeature(data, feature);
}
}
private SplitResult splitCategoricalFeature(List<DataPoint> data, String feature) {
Map<Object, List<DataPoint>> splitData = new HashMap<>();
for (DataPoint point : data) {
Object value = getFeatureValue(point, feature);
splitData.computeIfAbsent(value, k -> new ArrayList<>()).add(point);
}
double gain = calculateInformationGain(data, splitData.values());
return new SplitResult(feature, null, gain, splitData);
}
private SplitResult splitContinuousFeature(List<DataPoint> data, String feature) {
// Sort samples by the continuous feature's value
List<DataPoint> sortedData = data.stream()
.sorted(Comparator.comparingDouble(dp -> getNumericFeatureValue(dp, feature)))
.collect(Collectors.toList());
SplitResult bestSplit = null;
double bestGain = -Double.MAX_VALUE;
// Try each candidate threshold (midpoints between distinct adjacent values)
for (int i = 0; i < sortedData.size() - 1; i++) {
double currentValue = getNumericFeatureValue(sortedData.get(i), feature);
double nextValue = getNumericFeatureValue(sortedData.get(i + 1), feature);
if (currentValue != nextValue) {
double threshold = (currentValue + nextValue) / 2.0;
Map<Object, List<DataPoint>> splitData = new HashMap<>();
List<DataPoint> left = new ArrayList<>();
List<DataPoint> right = new ArrayList<>();
for (DataPoint point : sortedData) {
if (getNumericFeatureValue(point, feature) <= threshold) {
left.add(point);
} else {
right.add(point);
}
}
// Skip splits that would leave fewer than minSamplesLeaf samples on either side
if (left.size() < minSamplesLeaf || right.size() < minSamplesLeaf) {
    continue;
}
splitData.put("left", left);
splitData.put("right", right);
double gain = calculateInformationGain(data, splitData.values());
if (gain > bestGain) {
bestGain = gain;
bestSplit = new SplitResult(feature, threshold, gain, splitData);
}
}
}
return bestSplit;
}
// Impurity reduction from a split; the impurity measure follows the configured criterion
private double calculateInformationGain(List<DataPoint> parent, Collection<List<DataPoint>> children) {
    double parentImpurity = calculateImpurity(parent);
    double weightedChildrenImpurity = 0.0;
    for (List<DataPoint> child : children) {
        double weight = (double) child.size() / parent.size();
        weightedChildrenImpurity += weight * calculateImpurity(child);
    }
    return parentImpurity - weightedChildrenImpurity;
}
// Dispatch on the criterion ("gini" or "entropy") so the constructor argument takes effect
private double calculateImpurity(List<DataPoint> data) {
    return "gini".equals(criterion) ? calculateGini(data) : calculateEntropy(data);
}
private double calculateEntropy(List<DataPoint> data) {
if (data.isEmpty()) return 0.0;
Map<PlayTennis, Long> counts = data.stream()
.collect(Collectors.groupingBy(dp -> dp.playTennis, Collectors.counting()));
double entropy = 0.0;
for (Long count : counts.values()) {
double probability = (double) count / data.size();
entropy -= probability * (Math.log(probability) / Math.log(2));
}
return entropy;
}
private double calculateGini(List<DataPoint> data) {
if (data.isEmpty()) return 0.0;
Map<PlayTennis, Long> counts = data.stream()
.collect(Collectors.groupingBy(dp -> dp.playTennis, Collectors.counting()));
double gini = 1.0;
for (Long count : counts.values()) {
double probability = (double) count / data.size();
gini -= probability * probability;
}
return gini;
}
}
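Both the pruning code in the next section and the pipeline in section 2.4 call a ModelEvaluator that the article never defines. Below is a minimal sketch, assuming the purely categorical play-tennis features (it does not handle threshold splits): prediction follows the child whose key matches the sample's feature value, falling back to defaultChild for values unseen during training.

import java.util.List;

public class ModelEvaluator {
    public double calculateAccuracy(TreeNode root, List<DataPoint> data) {
        if (data.isEmpty()) return 0.0;
        long correct = data.stream()
            .filter(dp -> predict(root, dp) == dp.playTennis)
            .count();
        return (double) correct / data.size();
    }

    // Walk the tree until a leaf is reached
    public PlayTennis predict(TreeNode node, DataPoint point) {
        while (node instanceof DecisionNode) {
            DecisionNode dn = (DecisionNode) node;
            TreeNode next = dn.children.get(getFeatureValue(point, dn.feature));
            if (next == null) next = dn.defaultChild; // unseen feature value
            if (next == null) return PlayTennis.NO;   // arbitrary fallback
            node = next;
        }
        return ((LeafNode) node).label;
    }

    private Object getFeatureValue(DataPoint point, String feature) {
        switch (feature) {
            case "outlook":     return point.outlook;
            case "temperature": return point.temperature;
            case "humidity":    return point.humidity;
            case "wind":        return point.wind;
            default: throw new IllegalArgumentException("unknown feature: " + feature);
        }
    }
}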
2.3 Advanced Training Techniques: Pruning
public class TreePruner {
// Cost-complexity pruning (CCP)
public static TreeNode costComplexityPrune(TreeNode root, List<DataPoint> validationData, double alpha) {
if (root instanceof LeafNode) {
return root;
}
DecisionNode node = (DecisionNode) root;
// Recursively prune the subtrees first (bottom-up)
Map<Object, TreeNode> prunedChildren = new HashMap<>();
for (Map.Entry<Object, TreeNode> entry : node.children.entrySet()) {
TreeNode prunedChild = costComplexityPrune(entry.getValue(), validationData, alpha);
prunedChildren.put(entry.getKey(), prunedChild);
}
node.children = prunedChildren;
// Compare the cost of keeping this subtree vs. collapsing it into a leaf
double errorBefore = calculateError(node, validationData);
int leavesBefore = countLeaves(node);
double costBefore = errorBefore + alpha * leavesBefore;
LeafNode leafAfter = new LeafNode(getMajorityClassFromNode(node)); // majority class under this node, not of the whole validation set
double errorAfter = calculateError(leafAfter, validationData);
double costAfter = errorAfter + alpha; // a single leaf
// Prune when the collapsed leaf is no more costly
if (costAfter <= costBefore) {
return leafAfter;
}
return node;
}
// Reduced-error pruning (REP)
public static TreeNode reducedErrorPrune(TreeNode root, List<DataPoint> validationData) {
boolean changed;
do {
changed = false;
double originalAccuracy = calculateAccuracy(root, validationData);
// Try collapsing each internal node
List<DecisionNode> nonLeafNodes = findAllNonLeafNodes(root);
for (DecisionNode node : nonLeafNodes) {
// Save the original children
Map<Object, TreeNode> originalChildren = new HashMap<>(node.children);
// Temporarily replace the node's subtree with a leaf
LeafNode tempLeaf = new LeafNode(getMajorityClassFromNode(node));
node.children = new HashMap<>();
node.defaultChild = tempLeaf;
double newAccuracy = calculateAccuracy(root, validationData);
if (newAccuracy >= originalAccuracy) {
// Pruning did not hurt validation accuracy, so keep it
changed = true;
} else {
// Restore the original children
node.children = originalChildren;
node.defaultChild = null;
}
}
} while (changed);
return root;
}
private static List<DecisionNode> findAllNonLeafNodes(TreeNode root) {
List<DecisionNode> nodes = new ArrayList<>();
findNonLeafNodesRecursive(root, nodes);
return nodes;
}
private static void findNonLeafNodesRecursive(TreeNode node, List<DecisionNode> result) {
if (node instanceof DecisionNode) {
DecisionNode decisionNode = (DecisionNode) node;
result.add(decisionNode);
for (TreeNode child : decisionNode.children.values()) {
findNonLeafNodesRecursive(child, result);
}
}
}
}
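TreePruner also references several helpers (calculateAccuracy, calculateError, countLeaves, getMajorityClass, getMajorityClassFromNode) that are not shown. One plausible set of definitions, meant to live inside TreePruner and reusing the ModelEvaluator sketch above, is given here as an assumption rather than the author's code:

// Accuracy and error delegate to the ModelEvaluator sketch
private static double calculateAccuracy(TreeNode root, List<DataPoint> data) {
    return new ModelEvaluator().calculateAccuracy(root, data);
}

private static double calculateError(TreeNode root, List<DataPoint> data) {
    return 1.0 - calculateAccuracy(root, data);
}

private static int countLeaves(TreeNode node) {
    if (node instanceof LeafNode) return 1;
    return ((DecisionNode) node).children.values().stream()
        .mapToInt(TreePruner::countLeaves)
        .sum();
}

// Majority class over a sample list
private static PlayTennis getMajorityClass(List<DataPoint> data) {
    return data.stream()
        .collect(Collectors.groupingBy(dp -> dp.playTennis, Collectors.counting()))
        .entrySet().stream()
        .max(Map.Entry.comparingByValue())
        .map(Map.Entry::getKey)
        .orElse(PlayTennis.NO);
}

// Majority class among the leaves under a node (each leaf votes once)
private static PlayTennis getMajorityClassFromNode(TreeNode node) {
    List<PlayTennis> labels = new ArrayList<>();
    collectLeafLabels(node, labels);
    return labels.stream()
        .collect(Collectors.groupingBy(l -> l, Collectors.counting()))
        .entrySet().stream()
        .max(Map.Entry.comparingByValue())
        .map(Map.Entry::getKey)
        .orElse(PlayTennis.NO);
}

private static void collectLeafLabels(TreeNode node, List<PlayTennis> labels) {
    if (node instanceof LeafNode) {
        labels.add(((LeafNode) node).label);
    } else {
        for (TreeNode child : ((DecisionNode) node).children.values()) {
            collectLeafLabels(child, labels);
        }
    }
}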
2.4 Putting the Full Pipeline Together
public class CompleteTrainingPipeline {
public static void main(String[] args) {
// 1. Load the data
List<DataPoint> allData = loadTennisData();
// 2. Preprocess the data
DataPreprocessor preprocessor = new DataPreprocessor();
Map<String, Object> defaultValues = Map.of(
"outlook", Outlook.SUNNY,
"temperature", Temperature.MILD,
"humidity", Humidity.NORMAL,
"wind", Wind.WEAK
);
List<DataPoint> cleanedData = preprocessor.handleMissingValues(allData, defaultValues);
// 3. Fit the feature encoders
Map<String, Map<Object, Integer>> encoders = preprocessor.fitLabelEncoders(cleanedData);
// 4. Split the dataset
Collections.shuffle(cleanedData);
int splitIndex = (int) (cleanedData.size() * 0.7);
List<DataPoint> trainData = cleanedData.subList(0, splitIndex);
List<DataPoint> testData = cleanedData.subList(splitIndex, cleanedData.size());
// Carve a validation set out of the training data for pruning
int valSplitIndex = (int) (trainData.size() * 0.8);
List<DataPoint> finalTrainData = trainData.subList(0, valSplitIndex);
List<DataPoint> validationData = trainData.subList(valSplitIndex, trainData.size());
// 5. Configure the training parameters
List<String> features = Arrays.asList("outlook", "temperature", "humidity", "wind");
DecisionTreeTrainer trainer = new DecisionTreeTrainer(
10, // maxDepth
2, // minSamplesSplit
1, // minSamplesLeaf
0.01, // minImpurityDecrease
"entropy" // criterion
);
// 6. Train the initial decision tree
TreeNode initialTree = trainer.buildTree(finalTrainData, features, 0);
// 7. Prune the tree
TreeNode prunedTree = TreePruner.costComplexityPrune(initialTree, validationData, 0.01);
// 8. Evaluate the model
ModelEvaluator evaluator = new ModelEvaluator();
double trainAccuracy = evaluator.calculateAccuracy(prunedTree, finalTrainData);
double testAccuracy = evaluator.calculateAccuracy(prunedTree, testData);
System.out.println("训练准确率: " + trainAccuracy);
System.out.println("测试准确率: " + testAccuracy);
// 9. Analyze feature importance
Map<String, Double> featureImportance = calculateFeatureImportance(prunedTree, finalTrainData);
System.out.println("特征重要性: " + featureImportance);
// 10. Visualize the tree
prunedTree.print("");
}
private static Map<String, Double> calculateFeatureImportance(TreeNode root, List<DataPoint> data) {
Map<String, Double> importance = new HashMap<>();
calculateImportanceRecursive(root, data, importance);
return importance;
}
private static void calculateImportanceRecursive(TreeNode node, List<DataPoint> data,
Map<String, Double> importance) {
if (node instanceof DecisionNode) {
DecisionNode decisionNode = (DecisionNode) node;
// Add this node's information gain to its feature's running total
double nodeGain = calculateNodeGain(decisionNode, data);
importance.put(decisionNode.feature,
importance.getOrDefault(decisionNode.feature, 0.0) + nodeGain);
// Recurse into the children
for (TreeNode child : decisionNode.children.values()) {
calculateImportanceRecursive(child, data, importance);
}
}
}
}
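The pipeline above leaves loadTennisData (and calculateNodeGain) undefined. As a stand-in for a real data source, here is one way to load the classic 14-sample play-tennis dataset (Quinlan's weather data) that this article's enums mirror; the TennisData class name is an assumption:

import java.util.ArrayList;
import java.util.List;

public class TennisData {
    // The classic 14-sample play-tennis dataset: outlook, temperature, humidity, wind, label
    private static final String[][] ROWS = {
        {"SUNNY","HOT","HIGH","WEAK","NO"},          {"SUNNY","HOT","HIGH","STRONG","NO"},
        {"OVERCAST","HOT","HIGH","WEAK","YES"},      {"RAIN","MILD","HIGH","WEAK","YES"},
        {"RAIN","COOL","NORMAL","WEAK","YES"},       {"RAIN","COOL","NORMAL","STRONG","NO"},
        {"OVERCAST","COOL","NORMAL","STRONG","YES"}, {"SUNNY","MILD","HIGH","WEAK","NO"},
        {"SUNNY","COOL","NORMAL","WEAK","YES"},      {"RAIN","MILD","NORMAL","WEAK","YES"},
        {"SUNNY","MILD","NORMAL","STRONG","YES"},    {"OVERCAST","MILD","HIGH","STRONG","YES"},
        {"OVERCAST","HOT","NORMAL","WEAK","YES"},    {"RAIN","MILD","HIGH","STRONG","NO"}
    };

    public static List<DataPoint> loadTennisData() {
        List<DataPoint> data = new ArrayList<>();
        for (String[] row : ROWS) {
            DataPoint dp = new DataPoint();
            dp.outlook = Outlook.valueOf(row[0]);
            dp.temperature = Temperature.valueOf(row[1]);
            dp.humidity = Humidity.valueOf(row[2]);
            dp.wind = Wind.valueOf(row[3]);
            dp.playTennis = PlayTennis.valueOf(row[4]);
            data.add(dp);
        }
        return data;
    }
}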
3. Best Practices for Training Decision Trees
3.1 Hyperparameter Tuning
public class HyperparameterTuner {
public static Map<String, Object> tuneParameters(List<DataPoint> trainData,
List<DataPoint> validationData) {
Map<String, Object> bestParams = new HashMap<>();
double bestAccuracy = 0.0;
// Define the parameter search space
int[] maxDepths = {3, 5, 7, 10, 15};
int[] minSamplesSplits = {2, 5, 10};
double[] minImpurityDecreases = {0.0, 0.01, 0.05};
String[] criteria = {"gini", "entropy"};
// Grid search
for (int maxDepth : maxDepths) {
for (int minSamplesSplit : minSamplesSplits) {
for (double minImpurityDecrease : minImpurityDecreases) {
for (String criterion : criteria) {
DecisionTreeTrainer trainer = new DecisionTreeTrainer(
maxDepth, minSamplesSplit, 1, minImpurityDecrease, criterion
);
TreeNode tree = trainer.buildTree(trainData,
Arrays.asList("outlook", "temperature", "humidity", "wind"), 0);
double accuracy = new ModelEvaluator().calculateAccuracy(tree, validationData);
if (accuracy > bestAccuracy) {
bestAccuracy = accuracy;
bestParams.put("maxDepth", maxDepth);
bestParams.put("minSamplesSplit", minSamplesSplit);
bestParams.put("minImpurityDecrease", minImpurityDecrease);
bestParams.put("criterion", criterion);
}
}
}
}
}
return bestParams;
}
}
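Grid search against a single validation split is noisy, particularly on a dataset as small as the tennis data: the best parameters can change with the shuffle. A common refinement is to score each parameter combination by k-fold cross-validation and keep the combination with the best mean accuracy. A minimal sketch, reusing the trainer and evaluator above:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class CrossValidator {
    // Mean validation accuracy over k folds for one parameter combination
    public static double crossValidate(List<DataPoint> data, int k,
                                       int maxDepth, int minSamplesSplit,
                                       double minImpurityDecrease, String criterion) {
        List<String> features = Arrays.asList("outlook", "temperature", "humidity", "wind");
        int foldSize = data.size() / k;
        double totalAccuracy = 0.0;
        for (int fold = 0; fold < k; fold++) {
            int start = fold * foldSize;
            int end = (fold == k - 1) ? data.size() : start + foldSize;
            // Hold out one fold; train on the rest
            List<DataPoint> valFold = data.subList(start, end);
            List<DataPoint> trainFolds = new ArrayList<>(data.subList(0, start));
            trainFolds.addAll(data.subList(end, data.size()));
            DecisionTreeTrainer trainer = new DecisionTreeTrainer(
                maxDepth, minSamplesSplit, 1, minImpurityDecrease, criterion);
            TreeNode tree = trainer.buildTree(trainFolds, features, 0);
            totalAccuracy += new ModelEvaluator().calculateAccuracy(tree, valFold);
        }
        return totalAccuracy / k;
    }
}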
3.2 Strategies for Controlling Overfitting
- Early stopping: set a sensible maximum depth and minimum sample counts
- Pruning: apply cost-complexity pruning or reduced-error pruning
- Feature selection: drop uninformative features
- Ensemble methods: use random forests or gradient-boosted trees (see the bagging sketch after this list)
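To make the ensemble idea concrete, here is a minimal bagging sketch built on the DecisionTreeTrainer above: each tree trains on a bootstrap sample and prediction is a majority vote. A full random forest would additionally subsample features at each split; that part is omitted here, and the BaggedTrees class is an illustration, not the author's code.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

public class BaggedTrees {
    private final List<TreeNode> trees = new ArrayList<>();
    private final ModelEvaluator evaluator = new ModelEvaluator();

    // Train nTrees trees, each on a bootstrap sample (sampling with replacement)
    public void fit(List<DataPoint> data, int nTrees, long seed) {
        List<String> features = Arrays.asList("outlook", "temperature", "humidity", "wind");
        Random rng = new Random(seed);
        for (int t = 0; t < nTrees; t++) {
            List<DataPoint> bootstrap = new ArrayList<>(data.size());
            for (int i = 0; i < data.size(); i++) {
                bootstrap.add(data.get(rng.nextInt(data.size())));
            }
            DecisionTreeTrainer trainer = new DecisionTreeTrainer(10, 2, 1, 0.0, "entropy");
            trees.add(trainer.buildTree(bootstrap, features, 0));
        }
    }

    // Majority vote across all trees
    public PlayTennis predict(DataPoint point) {
        int yes = 0;
        for (TreeNode tree : trees) {
            if (evaluator.predict(tree, point) == PlayTennis.YES) yes++;
        }
        return yes * 2 >= trees.size() ? PlayTennis.YES : PlayTennis.NO;
    }
}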
3.3 Handling Class Imbalance
public class ClassImbalanceHandler {
public static List<DataPoint> handleImbalance(List<DataPoint> data, double targetRatio) {
Map<PlayTennis, List<DataPoint>> groupedData = data.stream()
.collect(Collectors.groupingBy(dp -> dp.playTennis));
long majorityCount = groupedData.values().stream()
.mapToLong(List::size)
.max().orElse(0);
List<DataPoint> balancedData = new ArrayList<>();
for (Map.Entry<PlayTennis, List<DataPoint>> entry : groupedData.entrySet()) {
List<DataPoint> classData = entry.getValue();
long targetSize = (long) (majorityCount * targetRatio);
if (classData.size() < targetSize) {
    // Oversample: add whole copies of the class, then a partial slice to hit targetSize exactly
    long added = 0;
    while (added + classData.size() <= targetSize) {
        balancedData.addAll(classData);
        added += classData.size();
    }
    balancedData.addAll(classData.subList(0, (int) (targetSize - added)));
} else if (classData.size() > targetSize) {
    // Undersample: keep a random subset
    Collections.shuffle(classData);
    balancedData.addAll(classData.subList(0, (int) targetSize));
} else {
    balancedData.addAll(classData);
}
}
Collections.shuffle(balancedData);
return balancedData;
}
}
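Note that duplicating minority samples makes the tree see identical points several times, which can exaggerate purity at some nodes. If your implementation supports sample weights, class weighting achieves a similar effect without copying data, and interpolation-based oversampling such as SMOTE is a common alternative for continuous features. Whatever method you choose, balance only the training split, never the test set, so evaluation still reflects the real class distribution.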
4. Summary
Training a decision tree is an end-to-end engineering effort that requires:
- Thorough data preprocessing: handle missing values, outliers, and feature encoding
- Sensible parameter settings: select hyperparameters via cross-validation
- Effective overfitting control: use pruning and regularization techniques
- Comprehensive model evaluation: assess performance with multiple metrics
- Ongoing performance monitoring: track the model's behavior in production
Follow these steps and best practices and you can train decision trees that perform well and generalize well. Remember that there is no one-size-fits-all solution: the right parameters and techniques depend on your specific data and business requirements.
This post is from 博客园 (cnblogs), author: NeoLshu. When reposting, please credit the original link: https://www.cnblogs.com/neolshu/p/19120300
