
Data Structures & Algorithms — A Complete Guide to Decision Tree Training: From Theory to Practice

1. Fundamentals of Decision Tree Training

At its core, decision tree training recursively selects the best feature on which to split the data, so that each resulting subset is as "pure" as possible (i.e., dominated by samples of a single class).

1.1 Mathematical Foundations of the Training Process

Decision tree training relies on a few key impurity measures:

Entropy

$$Entropy(S) = -\sum_{i=1}^{c} p_i \log_2 p_i$$

where $p_i$ is the proportion of samples of class $i$ in the set $S$.

Gini Impurity

$$Gini(S) = 1 - \sum_{i=1}^{c} p_i^2$$

Information Gain

$$IG(S, A) = Entropy(S) - \sum_{v \in Values(A)} \frac{|S_v|}{|S|}\, Entropy(S_v)$$

where $A$ is a feature and $S_v$ is the subset of samples for which feature $A$ takes the value $v$.
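
As a concrete check, consider Quinlan's classic 14-example play-tennis dataset (9 positive, 5 negative):

$$Entropy(S) = -\tfrac{9}{14}\log_2\tfrac{9}{14} - \tfrac{5}{14}\log_2\tfrac{5}{14} \approx 0.940$$

Splitting on Outlook (Sunny/Overcast/Rain with 5/4/5 samples, the Overcast branch entirely positive) gives $IG(S, Outlook) \approx 0.940 - \tfrac{5}{14}(0.971) - \tfrac{4}{14}(0) - \tfrac{5}{14}(0.971) \approx 0.246$, the highest of the four features, which is why Outlook ends up at the root.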

2. The Complete Training Workflow in Detail

2.1 Data Preparation
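
The code in this article references a small data model that the post never shows. Below is a minimal sketch, assuming `import java.util.*;` and `import java.util.stream.*;` throughout; every name in it (the enums, DataPoint, the TreeNode hierarchy, SplitResult) is an assumption introduced here so the examples compile:

// Assumed data model for the play-tennis examples (not in the original post).
enum Outlook { SUNNY, OVERCAST, RAIN }
enum Temperature { HOT, MILD, COOL }
enum Humidity { HIGH, NORMAL }
enum Wind { WEAK, STRONG }
enum PlayTennis { YES, NO }

class DataPoint {
    Outlook outlook;
    Temperature temperature;
    Humidity humidity;
    Wind wind;
    PlayTennis playTennis; // class label
}

// Node hierarchy used by the trainer and pruner below.
abstract class TreeNode {
    abstract void print(String indent);
}

class LeafNode extends TreeNode {
    PlayTennis prediction;
    LeafNode(PlayTennis prediction) { this.prediction = prediction; }
    void print(String indent) { System.out.println(indent + "Leaf: " + prediction); }
}

class DecisionNode extends TreeNode {
    String feature;
    Double threshold;                      // null for categorical splits
    Map<Object, TreeNode> children = new HashMap<>();
    TreeNode defaultChild;                 // used by reduced-error pruning

    DecisionNode(String feature, Double threshold) {
        this.feature = feature;
        this.threshold = threshold;
    }
    void addChild(Object value, TreeNode child) { children.put(value, child); }
    void print(String indent) {
        System.out.println(indent + "Split on " + feature);
        children.forEach((v, c) -> { System.out.println(indent + "  = " + v + ":"); c.print(indent + "    "); });
    }
}

// Result of evaluating one candidate split.
class SplitResult {
    String feature;
    Double threshold;
    double gain;
    Map<Object, List<DataPoint>> splitData;

    SplitResult(String feature, Double threshold, double gain,
                Map<Object, List<DataPoint>> splitData) {
        this.feature = feature;
        this.threshold = threshold;
        this.gain = gain;
        this.splitData = splitData;
    }
}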

public class DataPreprocessor {
    // Fill in missing values using caller-supplied defaults
    public static List<DataPoint> handleMissingValues(List<DataPoint> data, 
                                                    Map<String, Object> defaultValues) {
        List<DataPoint> cleanedData = new ArrayList<>();
        
        for (DataPoint point : data) {
            DataPoint cleanedPoint = new DataPoint();
            
            // Impute each feature's missing value from the defaults map
            if (point.outlook == null) {
                cleanedPoint.outlook = (Outlook) defaultValues.getOrDefault("outlook", Outlook.SUNNY);
            } else {
                cleanedPoint.outlook = point.outlook;
            }
            
            if (point.temperature == null) {
                cleanedPoint.temperature = (Temperature) defaultValues.getOrDefault("temperature", Temperature.MILD);
            } else {
                cleanedPoint.temperature = point.temperature;
            }
            
            // Handle the remaining features (humidity, wind) the same way...
            cleanedPoint.playTennis = point.playTennis;
            cleanedData.add(cleanedPoint);
        }
        
        return cleanedData;
    }
    
    // Feature encoding: map each categorical value to an integer code
    public static Map<String, Map<Object, Integer>> fitLabelEncoders(List<DataPoint> data) {
        Map<String, Map<Object, Integer>> encoders = new HashMap<>();
        
        // Build one encoder per categorical feature
        encoders.put("outlook", createEncoder(data.stream().map(dp -> dp.outlook).collect(Collectors.toList())));
        encoders.put("temperature", createEncoder(data.stream().map(dp -> dp.temperature).collect(Collectors.toList())));
        encoders.put("humidity", createEncoder(data.stream().map(dp -> dp.humidity).collect(Collectors.toList())));
        encoders.put("wind", createEncoder(data.stream().map(dp -> dp.wind).collect(Collectors.toList())));
        
        return encoders;
    }
    
    private static Map<Object, Integer> createEncoder(List<Object> values) {
        Map<Object, Integer> encoder = new HashMap<>();
        int code = 0;
        for (Object value : values) {
            if (value != null && !encoder.containsKey(value)) {
                encoder.put(value, code++);
            }
        }
        return encoder;
    }
}
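
fitLabelEncoders only learns the value-to-code mappings; applying them is a separate step the post omits. A minimal sketch (the encode method below is an assumption introduced here, not original code):

// Hypothetical application of fitted encoders: turn one DataPoint into an
// int[] in a fixed feature order, with -1 marking values unseen during fitting.
public static int[] encode(DataPoint dp, Map<String, Map<Object, Integer>> encoders) {
    return new int[] {
        encoders.get("outlook").getOrDefault(dp.outlook, -1),
        encoders.get("temperature").getOrDefault(dp.temperature, -1),
        encoders.get("humidity").getOrDefault(dp.humidity, -1),
        encoders.get("wind").getOrDefault(dp.wind, -1)
    };
}

Note that the tree in this article splits on the enum values directly, so encoding is optional here; it matters mainly when handing data to libraries that expect numeric feature matrices.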

2.2 The Core Training Algorithm

public class DecisionTreeTrainer {
    private int maxDepth;
    private int minSamplesSplit;
    private int minSamplesLeaf;
    private double minImpurityDecrease;
    private String criterion;
    
    public DecisionTreeTrainer(int maxDepth, int minSamplesSplit, 
                              int minSamplesLeaf, double minImpurityDecrease,
                              String criterion) {
        this.maxDepth = maxDepth;
        this.minSamplesSplit = minSamplesSplit;
        this.minSamplesLeaf = minSamplesLeaf;
        this.minImpurityDecrease = minImpurityDecrease;
        this.criterion = criterion;
    }
    
    public TreeNode buildTree(List<DataPoint> data, List<String> features, int depth) {
        // Check the stopping conditions
        if (shouldStop(data, features, depth)) {
            return createLeafNode(data);
        }
        
        // Choose the best feature to split on
        SplitResult bestSplit = findBestSplit(data, features);
        
        if (bestSplit == null || bestSplit.gain < minImpurityDecrease) {
            return createLeafNode(data);
        }
        
        // Create a decision node
        DecisionNode node = new DecisionNode(bestSplit.feature, bestSplit.threshold);
        
        // Recursively build the subtrees. ID3-style: a categorical feature is
        // consumed once used; a continuous feature (threshold != null) stays
        // available for further splits deeper in the tree.
        List<String> remainingFeatures = new ArrayList<>(features);
        if (bestSplit.threshold == null) {
            remainingFeatures.remove(bestSplit.feature);
        }
        
        for (Map.Entry<Object, List<DataPoint>> entry : bestSplit.splitData.entrySet()) {
            TreeNode childNode = buildTree(entry.getValue(), remainingFeatures, depth + 1);
            node.addChild(entry.getKey(), childNode);
        }
        
        return node;
    }
    
    // Helper assumed by the code above (not shown in the original post):
    // a leaf predicts the majority class of the samples reaching it.
    private TreeNode createLeafNode(List<DataPoint> data) {
        Map<PlayTennis, Long> counts = data.stream()
                .collect(Collectors.groupingBy(dp -> dp.playTennis, Collectors.counting()));
        PlayTennis majority = counts.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElse(PlayTennis.NO);
        return new LeafNode(majority);
    }
    
    private boolean shouldStop(List<DataPoint> data, List<String> features, int depth) {
        // Maximum depth reached
        if (depth >= maxDepth) {
            return true;
        }
        
        // Too few samples to split
        if (data.size() < minSamplesSplit) {
            return true;
        }
        
        // All samples belong to the same class
        if (isPure(data)) {
            return true;
        }
        
        // No features left to split on
        if (features.isEmpty()) {
            return true;
        }
        
        return false;
    }
    
    private boolean isPure(List<DataPoint> data) {
        if (data.isEmpty()) return true;
        
        PlayTennis firstClass = data.get(0).playTennis;
        return data.stream().allMatch(dp -> dp.playTennis == firstClass);
    }
    
    private SplitResult findBestSplit(List<DataPoint> data, List<String> features) {
        SplitResult bestSplit = null;
        double bestGain = -Double.MAX_VALUE;
        
        for (String feature : features) {
            SplitResult split = calculateSplitForFeature(data, feature);
            if (split != null && split.gain > bestGain) {
                bestGain = split.gain;
                bestSplit = split;
            }
        }
        
        return bestSplit;
    }
    
    private SplitResult calculateSplitForFeature(List<DataPoint> data, String feature) {
        if (isCategoricalFeature(feature)) {
            return splitCategoricalFeature(data, feature);
        } else {
            return splitContinuousFeature(data, feature);
        }
    }
    
    private SplitResult splitCategoricalFeature(List<DataPoint> data, String feature) {
        Map<Object, List<DataPoint>> splitData = new HashMap<>();
        
        for (DataPoint point : data) {
            Object value = getFeatureValue(point, feature);
            splitData.computeIfAbsent(value, k -> new ArrayList<>()).add(point);
        }
        
        double gain = calculateInformationGain(data, splitData.values());
        return new SplitResult(feature, null, gain, splitData);
    }
    
    private SplitResult splitContinuousFeature(List<DataPoint> data, String feature) {
        // Sort the data by the continuous feature's value
        List<DataPoint> sortedData = data.stream()
                .sorted(Comparator.comparingDouble(dp -> getNumericFeatureValue(dp, feature)))
                .collect(Collectors.toList());
        
        SplitResult bestSplit = null;
        double bestGain = -Double.MAX_VALUE;
        
        // Try every candidate threshold between adjacent distinct values
        for (int i = 0; i < sortedData.size() - 1; i++) {
            double currentValue = getNumericFeatureValue(sortedData.get(i), feature);
            double nextValue = getNumericFeatureValue(sortedData.get(i + 1), feature);
            
            if (currentValue != nextValue) {
                double threshold = (currentValue + nextValue) / 2.0;
                
                Map<Object, List<DataPoint>> splitData = new HashMap<>();
                List<DataPoint> left = new ArrayList<>();
                List<DataPoint> right = new ArrayList<>();
                
                for (DataPoint point : sortedData) {
                    if (getNumericFeatureValue(point, feature) <= threshold) {
                        left.add(point);
                    } else {
                        right.add(point);
                    }
                }
                
                // Enforce minSamplesLeaf (declared but unused in the original):
                // skip thresholds that leave too few samples on either side.
                if (left.size() < minSamplesLeaf || right.size() < minSamplesLeaf) {
                    continue;
                }
                
                splitData.put("left", left);
                splitData.put("right", right);
                
                double gain = calculateInformationGain(data, splitData.values());
                
                if (gain > bestGain) {
                    bestGain = gain;
                    bestSplit = new SplitResult(feature, threshold, gain, splitData);
                }
            }
        }
        
        return bestSplit;
    }
    
    private double calculateInformationGain(List<DataPoint> parent, Collection<List<DataPoint>> children) {
        double parentImpurity = calculateImpurity(parent);
        
        double weightedChildrenImpurity = 0.0;
        for (List<DataPoint> child : children) {
            double weight = (double) child.size() / parent.size();
            weightedChildrenImpurity += weight * calculateImpurity(child);
        }
        
        return parentImpurity - weightedChildrenImpurity;
    }
    
    // Dispatch on the configured criterion. The original always used entropy,
    // which left the criterion constructor parameter (and calculateGini) unused.
    private double calculateImpurity(List<DataPoint> data) {
        return "gini".equals(criterion) ? calculateGini(data) : calculateEntropy(data);
    }
    
    private double calculateEntropy(List<DataPoint> data) {
        if (data.isEmpty()) return 0.0;
        
        Map<PlayTennis, Long> counts = data.stream()
                .collect(Collectors.groupingBy(dp -> dp.playTennis, Collectors.counting()));
        
        double entropy = 0.0;
        for (Long count : counts.values()) {
            double probability = (double) count / data.size();
            entropy -= probability * (Math.log(probability) / Math.log(2));
        }
        
        return entropy;
    }
    
    private double calculateGini(List<DataPoint> data) {
        if (data.isEmpty()) return 0.0;
        
        Map<PlayTennis, Long> counts = data.stream()
                .collect(Collectors.groupingBy(dp -> dp.playTennis, Collectors.counting()));
        
        double gini = 1.0;
        for (Long count : counts.values()) {
            double probability = (double) count / data.size();
            gini -= probability * probability;
        }
        
        return gini;
    }
    
    // Helpers assumed by the code above (not shown in the original post). All
    // four play-tennis features are categorical; a numeric feature would need
    // a corresponding accessor in getNumericFeatureValue.
    private boolean isCategoricalFeature(String feature) {
        return true;
    }
    
    private Object getFeatureValue(DataPoint point, String feature) {
        switch (feature) {
            case "outlook":     return point.outlook;
            case "temperature": return point.temperature;
            case "humidity":    return point.humidity;
            case "wind":        return point.wind;
            default: throw new IllegalArgumentException("Unknown feature: " + feature);
        }
    }
    
    private double getNumericFeatureValue(DataPoint point, String feature) {
        // No continuous features in this dataset; extend here if you add any.
        throw new UnsupportedOperationException("No numeric feature: " + feature);
    }
}
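
Both criteria are zero for a pure node and peak at a 50/50 class mix, and in practice they usually prefer the same splits. A small standalone check (a hypothetical demo class, not from the original post):

// Standalone comparison of entropy vs. Gini impurity for a binary class mix.
public class ImpurityComparison {
    public static void main(String[] args) {
        for (int i = 1; i <= 9; i++) {
            double p = i / 10.0;  // proportion of the positive class
            double entropy = -p * log2(p) - (1 - p) * log2(1 - p);
            double gini = 1 - p * p - (1 - p) * (1 - p);
            System.out.printf("p=%.1f  entropy=%.3f  gini=%.3f%n", p, entropy, gini);
        }
    }
    
    private static double log2(double x) {
        return Math.log(x) / Math.log(2);
    }
}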

2.3 Advanced Training Techniques: Pruning

public class TreePruner {
    // Cost-complexity pruning (CCP)
    public static TreeNode costComplexityPrune(TreeNode root, List<DataPoint> validationData, double alpha) {
        if (root instanceof LeafNode) {
            return root;
        }
        
        DecisionNode node = (DecisionNode) root;
        
        // Prune the subtrees first (bottom-up)
        Map<Object, TreeNode> prunedChildren = new HashMap<>();
        for (Map.Entry<Object, TreeNode> entry : node.children.entrySet()) {
            TreeNode prunedChild = costComplexityPrune(entry.getValue(), validationData, alpha);
            prunedChildren.put(entry.getKey(), prunedChild);
        }
        node.children = prunedChildren;
        
        // Compare the cost before and after collapsing this node
        double errorBefore = calculateError(node, validationData);
        int leavesBefore = countLeaves(node);
        double costBefore = errorBefore + alpha * leavesBefore;
        
        // Note: ideally the replacement leaf predicts the majority class of the
        // *training* samples reaching this node; lacking per-node sample counts,
        // this sketch uses the validation set's majority class instead.
        LeafNode leafAfter = new LeafNode(getMajorityClass(validationData));
        double errorAfter = calculateError(leafAfter, validationData);
        double costAfter = errorAfter + alpha; // a single leaf
        
        // Prune if the collapsed node is no more costly
        if (costAfter <= costBefore) {
            return leafAfter;
        }
        
        return node;
    }
    
    // Reduced-error pruning (REP)
    public static TreeNode reducedErrorPrune(TreeNode root, List<DataPoint> validationData) {
        boolean changed;
        do {
            changed = false;
            double originalAccuracy = calculateAccuracy(root, validationData);
            
            // Try collapsing each internal node in turn
            List<DecisionNode> nonLeafNodes = findAllNonLeafNodes(root);
            for (DecisionNode node : nonLeafNodes) {
                // Save the original children
                Map<Object, TreeNode> originalChildren = new HashMap<>(node.children);
                
                // Temporarily replace the subtree with a leaf; classification
                // must fall through to defaultChild when children is empty.
                LeafNode tempLeaf = new LeafNode(getMajorityClassFromNode(node));
                node.children = new HashMap<>();
                node.defaultChild = tempLeaf;
                
                double newAccuracy = calculateAccuracy(root, validationData);
                
                if (newAccuracy >= originalAccuracy) {
                    // Pruning did not hurt validation accuracy: keep it, and
                    // update the baseline so later comparisons stay honest
                    // (the original kept comparing against a stale value).
                    originalAccuracy = newAccuracy;
                    changed = true;
                } else {
                    // Restore the original children
                    node.children = originalChildren;
                    node.defaultChild = null;
                }
            }
        } while (changed);
        
        return root;
    }
    
    private static List<DecisionNode> findAllNonLeafNodes(TreeNode root) {
        List<DecisionNode> nodes = new ArrayList<>();
        findNonLeafNodesRecursive(root, nodes);
        return nodes;
    }
    
    private static void findNonLeafNodesRecursive(TreeNode node, List<DecisionNode> result) {
        if (node instanceof DecisionNode) {
            DecisionNode decisionNode = (DecisionNode) node;
            result.add(decisionNode);
            
            for (TreeNode child : decisionNode.children.values()) {
                findNonLeafNodesRecursive(child, result);
            }
        }
    }
}
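
The pruner above calls several helpers the post never shows (calculateError, calculateAccuracy, countLeaves, getMajorityClass, getMajorityClassFromNode). A minimal sketch, assuming the data model from Section 2.1; featureValue and classify are names introduced here, and all of these would sit inside TreePruner:

    // Hypothetical helpers assumed by TreePruner; not in the original post.
    static Object featureValue(DataPoint p, String feature) {
        switch (feature) {
            case "outlook":     return p.outlook;
            case "temperature": return p.temperature;
            case "humidity":    return p.humidity;
            case "wind":        return p.wind;
            default: throw new IllegalArgumentException("Unknown feature: " + feature);
        }
    }
    
    static PlayTennis classify(TreeNode node, DataPoint point) {
        if (node instanceof LeafNode) {
            return ((LeafNode) node).prediction;
        }
        DecisionNode dn = (DecisionNode) node;
        if (dn.children.isEmpty() && dn.defaultChild != null) {
            return classify(dn.defaultChild, point);  // node collapsed by REP
        }
        // Categorical routing only; a continuous split (threshold != null)
        // would route on the "left"/"right" keys instead.
        TreeNode child = dn.children.get(featureValue(point, dn.feature));
        if (child == null) {
            return getMajorityClassFromNode(dn);      // value unseen in training
        }
        return classify(child, point);
    }
    
    static double calculateAccuracy(TreeNode root, List<DataPoint> data) {
        if (data.isEmpty()) return 0.0;
        long correct = data.stream()
                .filter(dp -> classify(root, dp) == dp.playTennis)
                .count();
        return (double) correct / data.size();
    }
    
    static double calculateError(TreeNode root, List<DataPoint> data) {
        return 1.0 - calculateAccuracy(root, data);
    }
    
    static int countLeaves(TreeNode node) {
        if (node instanceof LeafNode) return 1;
        DecisionNode dn = (DecisionNode) node;
        if (dn.children.isEmpty() && dn.defaultChild != null) {
            return countLeaves(dn.defaultChild);
        }
        return dn.children.values().stream().mapToInt(TreePruner::countLeaves).sum();
    }
    
    static PlayTennis getMajorityClass(List<DataPoint> data) {
        return data.stream()
                .collect(Collectors.groupingBy(dp -> dp.playTennis, Collectors.counting()))
                .entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElse(PlayTennis.NO);
    }
    
    static PlayTennis getMajorityClassFromNode(DecisionNode node) {
        // Without per-node sample counts we approximate the majority class by
        // voting over the predictions of the leaves beneath this node.
        Map<PlayTennis, Integer> votes = new HashMap<>();
        collectLeafVotes(node, votes);
        return votes.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElse(PlayTennis.NO);
    }
    
    private static void collectLeafVotes(TreeNode node, Map<PlayTennis, Integer> votes) {
        if (node instanceof LeafNode) {
            votes.merge(((LeafNode) node).prediction, 1, Integer::sum);
        } else {
            for (TreeNode c : ((DecisionNode) node).children.values()) {
                collectLeafVotes(c, votes);
            }
        }
    }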

2.4 Putting the Complete Pipeline Together

public class CompleteTrainingPipeline {
    public static void main(String[] args) {
        // 1. Load the data (loadTennisData is a loader stub, as in the original)
        List<DataPoint> allData = loadTennisData();
        
        // 2. Preprocess: impute missing values (static call; no instance needed)
        Map<String, Object> defaultValues = Map.of(
            "outlook", Outlook.SUNNY,
            "temperature", Temperature.MILD,
            "humidity", Humidity.NORMAL,
            "wind", Wind.WEAK
        );
        
        List<DataPoint> cleanedData = DataPreprocessor.handleMissingValues(allData, defaultValues);
        
        // 3. Fit the feature encoders
        Map<String, Map<Object, Integer>> encoders = DataPreprocessor.fitLabelEncoders(cleanedData);
        
        // 4. Split the dataset (70% train, 30% test)
        Collections.shuffle(cleanedData);
        int splitIndex = (int) (cleanedData.size() * 0.7);
        List<DataPoint> trainData = cleanedData.subList(0, splitIndex);
        List<DataPoint> testData = cleanedData.subList(splitIndex, cleanedData.size());
        
        // Carve a validation set out of the training data for pruning
        int valSplitIndex = (int) (trainData.size() * 0.8);
        List<DataPoint> finalTrainData = trainData.subList(0, valSplitIndex);
        List<DataPoint> validationData = trainData.subList(valSplitIndex, trainData.size());
        
        // 5. Configure the training parameters
        List<String> features = Arrays.asList("outlook", "temperature", "humidity", "wind");
        
        DecisionTreeTrainer trainer = new DecisionTreeTrainer(
            10,    // maxDepth
            2,     // minSamplesSplit
            1,     // minSamplesLeaf
            0.01,  // minImpurityDecrease
            "entropy" // criterion
        );
        
        // 6. Train the initial tree
        TreeNode initialTree = trainer.buildTree(finalTrainData, features, 0);
        
        // 7. Prune
        TreeNode prunedTree = TreePruner.costComplexityPrune(initialTree, validationData, 0.01);
        
        // 8. Evaluate
        ModelEvaluator evaluator = new ModelEvaluator();
        double trainAccuracy = evaluator.calculateAccuracy(prunedTree, finalTrainData);
        double testAccuracy = evaluator.calculateAccuracy(prunedTree, testData);
        
        System.out.println("Training accuracy: " + trainAccuracy);
        System.out.println("Test accuracy: " + testAccuracy);
        
        // 9. Feature importance analysis
        Map<String, Double> featureImportance = calculateFeatureImportance(prunedTree, finalTrainData);
        System.out.println("Feature importance: " + featureImportance);
        
        // 10. Visualize the tree
        prunedTree.print("");
    }
    
    private static Map<String, Double> calculateFeatureImportance(TreeNode root, List<DataPoint> data) {
        Map<String, Double> importance = new HashMap<>();
        calculateImportanceRecursive(root, data, importance);
        return importance;
    }
    
    private static void calculateImportanceRecursive(TreeNode node, List<DataPoint> data, 
                                                   Map<String, Double> importance) {
        if (node instanceof DecisionNode) {
            DecisionNode decisionNode = (DecisionNode) node;
            
            // Accumulate the information gain this node's split achieves
            double nodeGain = calculateNodeGain(decisionNode, data);
            importance.put(decisionNode.feature, 
                          importance.getOrDefault(decisionNode.feature, 0.0) + nodeGain);
            
            // Recurse with only the samples routed to each child (the original
            // passed the full set down, which overstates deeper features);
            // featureValue is the helper sketched in Section 2.3.
            Map<Object, List<DataPoint>> partitions = new HashMap<>();
            for (DataPoint p : data) {
                partitions.computeIfAbsent(TreePruner.featureValue(p, decisionNode.feature),
                        k -> new ArrayList<>()).add(p);
            }
            for (Map.Entry<Object, TreeNode> entry : decisionNode.children.entrySet()) {
                List<DataPoint> subset = partitions.getOrDefault(entry.getKey(), Collections.emptyList());
                calculateImportanceRecursive(entry.getValue(), subset, importance);
            }
        }
    }
}
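
Two pieces the pipeline references are not shown above: calculateNodeGain and ModelEvaluator. A minimal sketch, reusing TreePruner.featureValue and TreePruner.classify from the Section 2.3 helpers (all of which are assumptions introduced in this article, not original code); the two private methods would sit inside CompleteTrainingPipeline:

    // Hypothetical gain computation for one node: entropy reduction achieved
    // by partitioning the given samples on the node's feature.
    private static double calculateNodeGain(DecisionNode node, List<DataPoint> data) {
        if (data.isEmpty()) return 0.0;
        Map<Object, List<DataPoint>> partitions = new HashMap<>();
        for (DataPoint p : data) {
            partitions.computeIfAbsent(TreePruner.featureValue(p, node.feature),
                    k -> new ArrayList<>()).add(p);
        }
        double childEntropy = 0.0;
        for (List<DataPoint> part : partitions.values()) {
            childEntropy += ((double) part.size() / data.size()) * entropy(part);
        }
        return entropy(data) - childEntropy;
    }
    
    private static double entropy(List<DataPoint> data) {
        if (data.isEmpty()) return 0.0;
        Map<PlayTennis, Long> counts = data.stream()
                .collect(Collectors.groupingBy(dp -> dp.playTennis, Collectors.counting()));
        double h = 0.0;
        for (long c : counts.values()) {
            double p = (double) c / data.size();
            h -= p * Math.log(p) / Math.log(2);
        }
        return h;
    }

// Minimal evaluator assumed by the pipeline.
class ModelEvaluator {
    double calculateAccuracy(TreeNode root, List<DataPoint> data) {
        if (data.isEmpty()) return 0.0;
        long correct = data.stream()
                .filter(dp -> TreePruner.classify(root, dp) == dp.playTennis)
                .count();
        return (double) correct / data.size();
    }
}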

3. Best Practices for Training Decision Trees

3.1 Hyperparameter Tuning

public class HyperparameterTuner {
    public static Map<String, Object> tuneParameters(List<DataPoint> trainData, 
                                                   List<DataPoint> validationData) {
        Map<String, Object> bestParams = new HashMap<>();
        double bestAccuracy = 0.0;
        
        // Define the search space
        int[] maxDepths = {3, 5, 7, 10, 15};
        int[] minSamplesSplits = {2, 5, 10};
        double[] minImpurityDecreases = {0.0, 0.01, 0.05};
        String[] criteria = {"gini", "entropy"};
        
        // Grid search over all combinations
        for (int maxDepth : maxDepths) {
            for (int minSamplesSplit : minSamplesSplits) {
                for (double minImpurityDecrease : minImpurityDecreases) {
                    for (String criterion : criteria) {
                        
                        DecisionTreeTrainer trainer = new DecisionTreeTrainer(
                            maxDepth, minSamplesSplit, 1, minImpurityDecrease, criterion
                        );
                        
                        TreeNode tree = trainer.buildTree(trainData, 
                            Arrays.asList("outlook", "temperature", "humidity", "wind"), 0);
                        
                        double accuracy = new ModelEvaluator().calculateAccuracy(tree, validationData);
                        
                        if (accuracy > bestAccuracy) {
                            bestAccuracy = accuracy;
                            bestParams.put("maxDepth", maxDepth);
                            bestParams.put("minSamplesSplit", minSamplesSplit);
                            bestParams.put("minImpurityDecrease", minImpurityDecrease);
                            bestParams.put("criterion", criterion);
                        }
                    }
                }
            }
        }
        
        return bestParams;
    }
}
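
Grid search against a single validation split can itself overfit that split; k-fold cross-validation gives a more stable score per parameter combination. A minimal sketch (the crossValidate helper below is an assumption for illustration, not code from the original post; it could replace the single-split scoring inside tuneParameters):

// Hypothetical k-fold scorer for one hyperparameter combination.
static double crossValidate(List<DataPoint> data, int k, int maxDepth,
                            int minSamplesSplit, double minImpurityDecrease,
                            String criterion) {
    List<DataPoint> shuffled = new ArrayList<>(data);
    Collections.shuffle(shuffled, new Random(42));   // fixed seed for reproducibility
    int foldSize = shuffled.size() / k;
    double totalAccuracy = 0.0;
    
    for (int fold = 0; fold < k; fold++) {
        int from = fold * foldSize;
        int to = (fold == k - 1) ? shuffled.size() : from + foldSize;
        
        // Hold out one fold for validation, train on the rest.
        List<DataPoint> valFold = shuffled.subList(from, to);
        List<DataPoint> trainFold = new ArrayList<>(shuffled.subList(0, from));
        trainFold.addAll(shuffled.subList(to, shuffled.size()));
        
        DecisionTreeTrainer trainer = new DecisionTreeTrainer(
                maxDepth, minSamplesSplit, 1, minImpurityDecrease, criterion);
        TreeNode tree = trainer.buildTree(trainFold,
                Arrays.asList("outlook", "temperature", "humidity", "wind"), 0);
        totalAccuracy += new ModelEvaluator().calculateAccuracy(tree, valFold);
    }
    return totalAccuracy / k;   // mean validation accuracy across folds
}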

3.2 Strategies for Controlling Overfitting

  1. Early stopping: set a sensible maximum depth and minimum sample counts
  2. Pruning: apply cost-complexity pruning or reduced-error pruning
  3. Feature selection: drop uninformative features
  4. Ensemble methods: use random forests or gradient-boosted trees (see the bagging sketch below)
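
To illustrate item 4, here is a minimal bagging ensemble built on the DecisionTreeTrainer from Section 2.2. The class name and sampling scheme are illustrative assumptions (a true random forest would also subsample features at each split, omitted here), and prediction reuses the TreePruner.classify sketch from Section 2.3:

// Hypothetical bagged ensemble of decision trees.
class BaggedTrees {
    private final List<TreeNode> trees = new ArrayList<>();
    private final Random random = new Random(42);
    
    void fit(List<DataPoint> data, List<String> features, int nTrees) {
        for (int t = 0; t < nTrees; t++) {
            // Bootstrap sample: draw |data| points with replacement.
            List<DataPoint> bootstrap = new ArrayList<>(data.size());
            for (int i = 0; i < data.size(); i++) {
                bootstrap.add(data.get(random.nextInt(data.size())));
            }
            DecisionTreeTrainer trainer = new DecisionTreeTrainer(10, 2, 1, 0.0, "gini");
            trees.add(trainer.buildTree(bootstrap, features, 0));
        }
    }
    
    PlayTennis predict(DataPoint point) {
        // Majority vote across all trees in the ensemble.
        Map<PlayTennis, Integer> votes = new HashMap<>();
        for (TreeNode tree : trees) {
            votes.merge(TreePruner.classify(tree, point), 1, Integer::sum);
        }
        return votes.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElse(PlayTennis.NO);
    }
}

Averaging many high-variance trees trained on different bootstrap samples reduces variance without raising bias much, which is exactly what pruning a single tree cannot achieve on its own.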

3.3 Handling Class Imbalance

public class ClassImbalanceHandler {
    public static List<DataPoint> handleImbalance(List<DataPoint> data, double targetRatio) {
        Map<PlayTennis, List<DataPoint>> groupedData = data.stream()
                .collect(Collectors.groupingBy(dp -> dp.playTennis));
        
        long majorityCount = groupedData.values().stream()
                .mapToLong(List::size)
                .max().orElse(0);
        
        List<DataPoint> balancedData = new ArrayList<>();
        
        for (Map.Entry<PlayTennis, List<DataPoint>> entry : groupedData.entrySet()) {
            List<DataPoint> classData = entry.getValue();
            long targetSize = (long) (majorityCount * targetRatio);
            
            if (classData.size() < targetSize) {
                // Oversample: repeat minority samples until the target size is
                // reached. (The original looped on classData.size(), which never
                // changes, so it would never terminate, and then double-added.)
                List<DataPoint> oversampled = new ArrayList<>(classData);
                int i = 0;
                while (oversampled.size() < targetSize) {
                    oversampled.add(classData.get(i++ % classData.size()));
                }
                balancedData.addAll(oversampled);
            } else if (classData.size() > targetSize) {
                // Undersample: keep a random subset
                Collections.shuffle(classData);
                balancedData.addAll(classData.subList(0, (int) targetSize));
            } else {
                balancedData.addAll(classData);
            }
        }
        
        Collections.shuffle(balancedData);
        return balancedData;
    }
}
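
A quick usage sketch (variable names assumed): rebalance the training split before fitting, and evaluate on an untouched test set, since duplicated minority samples inflate training-set metrics:

// Hypothetical usage: rebalance, then train and evaluate as usual.
List<DataPoint> balanced = ClassImbalanceHandler.handleImbalance(trainData, 1.0);
DecisionTreeTrainer trainer = new DecisionTreeTrainer(10, 2, 1, 0.01, "entropy");
TreeNode tree = trainer.buildTree(balanced,
        Arrays.asList("outlook", "temperature", "humidity", "wind"), 0);
double testAccuracy = new ModelEvaluator().calculateAccuracy(tree, testData);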

4. Summary

Training a decision tree well is a systematic exercise that requires:

  1. Thorough data preprocessing: handle missing values and outliers, encode features
  2. Sensible parameter settings: choose hyperparameters via cross-validation
  3. Effective overfitting control: use pruning and regularization techniques
  4. Comprehensive model evaluation: assess performance with multiple metrics
  5. Ongoing performance monitoring: track how the model behaves in production

By following these steps and best practices, you can train decision tree models with strong performance and generalization. Remember: there is no one-size-fits-all solution; the right parameters and techniques depend on your specific data and business requirements.
