# 《集体智慧编程》读书笔记6

## 第六部分 决策树建模

### 预测注册用户

slashdot USA yes 18 None
digg USA yes 24 Basic
kiwitobes France yes 23 Basic
(direct) New Zealand no 12 None
(direct) UK no 21 Basic
slashdot France yes 19 None
digg USA no 18 None
google UK no 18 None
kiwitobes UK no 19 None
digg New Zealand yes 12 Basic
slashdot UK no 21 None
google UK yes 18 Basic
kiwitobes France yes 19 Basic

// Sample data set matching the table above.
// Columns: referrer site, country, read FAQ (yes/no), pages viewed, service chosen (result column).
// NOTE: the two "google" rows from the table were missing here; they are required
// later when Classify/MdClassify are demonstrated with "google" observations.
public class TreePredict
{
    public static List<object[]> MyData = new List<object[]>()
    {
        new object[]{"slashdot","USA","yes",18,"None"},
        new object[]{"digg","USA","yes",24,"Basic"},
        new object[]{"kiwitobes","France","yes",23,"Basic"},
        new object[]{"(direct)","New Zealand","no",12,"None"},
        new object[]{"(direct)","UK","no",21,"Basic"},
        new object[]{"slashdot","France","yes",19,"None"},
        new object[]{"digg","USA","no",18,"None"},
        new object[]{"google","UK","no",18,"None"},
        new object[]{"kiwitobes","UK","no",19,"None"},
        new object[]{"digg","New Zealand","yes",12,"Basic"},
        new object[]{"slashdot","UK","no",21,"None"},
        new object[]{"google","UK","yes",18,"Basic"},
        new object[]{"kiwitobes","France","yes",19,"Basic"}
    };
}


#### 实现决策树算法

/// <summary>
/// A node of the decision tree. Interior nodes carry a test
/// (column index plus reference value) and two children; leaf nodes
/// carry only the <see cref="Results"/> counts for the rows that reached them.
/// </summary>
public class DecisionNode
{
    /// <summary>Index (into the data rows) of the column this node tests.</summary>
    public int Col { get; set; }

    /// <summary>Reference value the column must equal (or, for numbers, be &gt;=) for the test to be true.</summary>
    public object Value { get; set; }

    /// <summary>Outcome counts for this branch; non-null only on leaf nodes.</summary>
    public Dictionary<string, int> Results { get; set; }

    /// <summary>Child followed when the node's test evaluates to true.</summary>
    public DecisionNode Tb { get; set; }

    /// <summary>Child followed when the node's test evaluates to false.</summary>
    public DecisionNode Fb { get; set; }

    /// <summary>Creates an empty node (used when building from an empty data set).</summary>
    public DecisionNode()
    {
    }

    /// <summary>Creates an interior (decision) node.</summary>
    public DecisionNode(int col, object value, DecisionNode tb, DecisionNode fb)
    {
        Col = col;
        Value = value;
        Tb = tb;
        Fb = fb;
    }

    /// <summary>Creates a leaf node holding the outcome counts.</summary>
    public DecisionNode(Dictionary<string, int> results)
    {
        Results = results;
    }
}

• Col表示这个节点判断条件对应的上面表格的列的索引
• Value 表示为了使判断条件为true，需要的值是多少
• Tb 当此节点验证结果为true时对应的子节点
• Fb 当此节点验证结果为false时对应的子节点
• Results 只有叶节点这个属性不为空，表示这个分支的结果

#### 训练决策树

// Split a data set on one column. Numeric reference values (int/float) split
// on "row value >= reference"; anything else splits on string equality.
// Returns (matching rows, non-matching rows), each preserving input order.
public Tuple<List<object[]>, List<object[]>> DivideSet(List<object[]> rows, int column, object value)
{
    // Choose the predicate that decides whether a row belongs to the first set.
    Func<object[], bool> isMatch;
    if (value is int)
    {
        var threshold = Convert.ToInt32(value);
        isMatch = row => Convert.ToInt32(row[column]) >= threshold;
    }
    else if (value is float)
    {
        var threshold = Convert.ToSingle(value);
        isMatch = row => Convert.ToSingle(row[column]) >= threshold;
    }
    else
    {
        var reference = value.ToString();
        isMatch = row => row[column].ToString() == reference;
    }

    // Partition the rows in a single pass.
    var matching = new List<object[]>();
    var rest = new List<object[]>();
    foreach (var row in rows)
    {
        (isMatch(row) ? matching : rest).Add(row);
    }
    return Tuple.Create(matching, rest);
}


// Demo: split the sample data on column 2 ("read FAQ") with reference value "yes".
var treePredict = new TreePredict();
var splitSet = treePredict.DivideSet(TreePredict.MyData, 2, "yes");
Action<object[]> printRow = r => { Console.WriteLine($"{r[0]},{r[1]},{r[2]},{r[3]},{r[4]}"); };
Console.WriteLine("set1:");
splitSet.Item1.ForEach(r => printRow(r));
Console.WriteLine("set2:");
splitSet.Item2.ForEach(r => printRow(r));

拆分结果如下：

是否阅读过FAQ - True：None Premium Premium None Basic Basic Basic
是否阅读过FAQ - False：Premium None None Basic None Basic None

可以看到拆分的两个集合中不同的结果混在一起。这说明按照“是否阅读过FAQ”这个列来分类不合理。所以需要找一种方法来确定使用哪个列来对数据表进行划分。

#### 选择拆分方案

好的拆分方案得到的集合中，结果列的混杂程度应尽可能的小。首先我们在TreePredict中添加一个函数UniqueCounts，对结果列进行计数：

// 对结果列（最后一列）进行计数
public Dictionary<string, int> UniqueCounts(List<object[]> rows)
{
    var results = new Dictionary<string, int>();
    foreach (var row in rows)
    {
        // 计数结果在最后一列
        var r = row.Last().ToString();
        if (!results.ContainsKey(r))
            results.Add(r, 0);
        results[r] += 1;
    }
    return results;
}

这个函数的作用就是找出所有不同的结果，并对结果进行计数。这个结果将用于计算数据集合的混杂程度。

下面介绍两种度量混杂程度的算法：基尼不纯度(Gini impurity)和熵(entropy)。

#### 基尼不纯度

基尼不纯度是指将来自集合中某个值随机应用于集合中某一数据项的预期误差率。假如集合中的每个数据都是同一分类，那么推测总是正确的，所以预期误差率总是为0。而如果有4种类别的数据且数量相等，则只有25%的概率推测正确，所以误差率为75%。

在TreePredict中添加基尼不纯度的计算方法GiniImpurity：

// 随机放置的数据项出现于错误分类中的概率
public float GiniImpurity(List<object[]> rows)
{
    var total = rows.Count;
    var counts = UniqueCounts(rows);
    var imp = 0f;
    foreach (var k1 in counts.Keys)
    {
        var p1 = counts[k1] / (float)total;
        foreach (var k2 in counts.Keys)
        {
            if (k1 == k2) continue;
            var p2 = counts[k2] / (float)total;
            imp += p1 * p2;
        }
    }
    return imp;
}

这个函数累加了某一行数据被随机分配到错误结果的概率，得到总概率。这个概率越高说明数据拆分越不理想，而0说明每一行数据都被分到正确的集合中。

#### 熵

熵在信息理论中用于表示集合的无序程度，和这里要求的混杂程度很类似。将计算熵的Entropy方法加入到TreePredict中：

// 熵是遍历所有可能结果之后所得到的p(x)log(p(x))之和
public float Entropy(List<object[]> rows)
{
    Func<float, float> log2 = x => (float)(Math.Log(x) / Math.Log(2));
    var results = UniqueCounts(rows);
    // 开始计算熵值
    var ent = 0f;
    foreach (var r in results.Keys)
    {
        var p = results[r] / (float)rows.Count;
        ent -= p * log2(p);
    }
    return ent;
}

如果所有结果都相同，上面方法计算的熵为0，而如果数据集越是混乱，相应熵就越高。我们的目标就是拆分数据集并降低熵。

我们通过下面的代码测试基尼不纯度和熵的计算：

var treePredict = new TreePredict();
var gini = treePredict.GiniImpurity(TreePredict.MyData);
Console.WriteLine(gini);
var entr = treePredict.Entropy(TreePredict.MyData);
Console.WriteLine(entr);
var setTuple = treePredict.DivideSet(TreePredict.MyData, 2, "yes");
gini = treePredict.GiniImpurity(setTuple.Item1);
Console.WriteLine(gini);
entr = treePredict.Entropy(setTuple.Item1);
Console.WriteLine(entr);

在现实中熵的使用更为普遍，后文将以熵作为度量混杂程度的标准。

#### 递归方式构造决策树

有了判断集合混杂度的方法，我们可以通过计算群组拆分后熵的信息增益来判断拆分的好坏。信息增益是指整个群组的熵与拆分后两个新群组的熵的加权平均值之间的差。差值即信息增益越大说明拆分效果越好。我们在每个列上都进行拆分尝试并计算信息增益，最终找出信息增益最大的列。

对于新得到的子集合，如果子集合可以继续拆分（如结果有不同值存在才有必要继续拆分），将在其上继续这个拆分过程直到信息增益为0。

我们在TreePredict添加一个递归函数BuildTree来实现这个递归构建树的过程。

public DecisionNode BuildTree(List<object[]> rows, Func<List<object[]>, float> scoref = null)
{
    if (scoref == null)
        scoref = Entropy;
    var rowsCount = rows.Count;
    if (rowsCount == 0)
        return new DecisionNode();
    var currentScore = scoref(rows);

    //定义一些变量记录最佳拆分条件
    var bestGain = 0f;
    Tuple<int, object> bestCriteria = null;
    Tuple<List<object[]>, List<object[]>> bestSets = null;

    var columnCount = rows[0].Length - 1;
    for (int i = 0; i < columnCount; i++)
    {
        // 在当前列中生成一个由不同值构成的序列
        var columnValues = new List<object>();
        if (rows[0][i] is int)
            columnValues = rows.Select(r => r[i]).Cast<int>().Distinct().Cast<object>().ToList();
        else if (rows[0][i] is float)
            columnValues = rows.Select(r => r[i]).Cast<float>().Distinct().Cast<object>().ToList();
        else
            columnValues = rows.Select(r => r[i].ToString()).Distinct().Cast<object>().ToList();

        // 根据这一列中的每个值，尝试对数据集进行拆分
        foreach (var value in columnValues)
        {
            var setTuple = DivideSet(rows, i, value);
            var set1 = setTuple.Item1;
            var set2 = setTuple.Item2;

            //信息增益
            var p = set1.Count / (float)rowsCount;
            var gain = currentScore - p * scoref(set1) - (1 - p) * scoref(set2);
            if (gain > bestGain && set1.Count > 0 && set2.Count > 0)
            {
                bestGain = gain;
                bestCriteria = Tuple.Create(i, value);
                bestSets = setTuple;
            }
        }
    }

    // 创建子分支
    if (bestGain > 0)
    {
        var trueBranch = BuildTree(bestSets.Item1);
        var falseBranch = BuildTree(bestSets.Item2);
        return new DecisionNode(
            col: bestCriteria.Item1,
            value: bestCriteria.Item2,
            tb: trueBranch,
            fb: falseBranch
        );
    }
    else
    {
        return new DecisionNode(UniqueCounts(rows));
    }
}

代码中，我们在每一列上，按照列中每一个不同的值进行拆分尝试，并找到一个使信息增益最大的拆分方式。递归这个过程直到树构建完成。我们可以通过下面的代码测试决策树的构造：

var treePredict = new TreePredict();
treePredict.BuildTree(TreePredict.MyData);

很显然，现在看不到任何可视化的结果，下一节将编写代码以文本方式打印决策树。

#### 展示决策树

仍然是在TreePredict中建立新方法，PrintTree方法将以文本方式展示树，由于是遍历树，这个函数自然也是一个递归函数。

public void PrintTree(DecisionNode tree, string indent = "")
{
//是叶节点吗？
if (tree.Results != null)
Console.WriteLine(JsonConvert.SerializeObject(tree.Results));
else
{
//打印判断条件
Console.WriteLine($"{tree.Col}:{tree.Value}? ");

//打印分支
Console.Write($"{indent}T->");
PrintTree(tree.Tb, indent + "  ");
Console.Write($"{indent}F->");
PrintTree(tree.Fb, indent + "  ");
}
}


// Demo: build the decision tree from the sample data and print it as indented text.
var treePredict = new TreePredict();
var tree= treePredict.BuildTree(TreePredict.MyData);
treePredict.PrintTree(tree);


#### 使用决策树分类

/// <summary>
/// Classify a (complete) observation by walking the tree from the given
/// node down to a leaf, and return that leaf's outcome counts.
/// </summary>
public Dictionary<string,int> Classify(object[] observation, DecisionNode tree)
{
    // Leaf reached: its Results dictionary is the classification.
    if (tree.Results != null)
        return tree.Results;

    // Decide which branch to follow by comparing the observed value
    // against this node's split value.
    var observed = observation[tree.Col];
    bool followTrue;
    if (observed is int || observed is float)
    {
        // Numeric columns split on "observed >= split value".
        var observedNum = observed is int ? Convert.ToInt32(observed) : Convert.ToSingle(observed);
        var splitNum = tree.Value is int ? Convert.ToInt32(tree.Value) : Convert.ToSingle(tree.Value);
        followTrue = observedNum >= splitNum;
    }
    else
    {
        // Nominal columns split on string equality.
        followTrue = observed.ToString() == tree.Value.ToString();
    }
    return Classify(observation, followTrue ? tree.Tb : tree.Fb);
}


// Demo: classify a new observation (direct referrer, USA, read FAQ, 5 pages viewed).
// The observation has no result column; indexes 0-3 match the Col values used by the tree.
var treePredict = new TreePredict();
var tree= treePredict.BuildTree(TreePredict.MyData);
var result = treePredict.Classify(new object[] {"(direct)","USA","yes",5}, tree);
Console.WriteLine(JsonConvert.SerializeObject(result));


### 决策树剪枝

// Prune the tree bottom-up: wherever both children of a node are leaves,
// merge them back into one leaf if splitting them gains less entropy than
// mingain. Mutates the tree in place.
public void Prune(DecisionNode tree, float mingain)
{
    // Recurse first so deeper merges can cascade upward.
    if (tree.Tb.Results == null)
        Prune(tree.Tb, mingain);
    if (tree.Fb.Results == null)
        Prune(tree.Fb, mingain);

    // If both children are now leaves, decide whether to merge them.
    if (tree.Tb.Results != null && tree.Fb.Results != null)
    {
        // Rebuild each leaf's rows from its outcome counts (one single-column
        // row per counted outcome). The original code combined the repeated
        // rows with Union, but Union is a SET union: the repeated references
        // collapsed to one row per outcome and the counts were lost.
        var tb = tree.Tb.Results
            .SelectMany(kvp => Enumerable.Repeat(new object[] { kvp.Key }, kvp.Value))
            .ToList();
        var fb = tree.Fb.Results
            .SelectMany(kvp => Enumerable.Repeat(new object[] { kvp.Key }, kvp.Value))
            .ToList();

        // Entropy gained by keeping the split: entropy of the merged rows
        // minus the average entropy of the two leaves.
        // (Original code was missing parentheses and computed
        // Entropy(tb) + Entropy(fb)/2 instead of the average.)
        var mergeNode = tb.Concat(fb).ToList();
        var delta = Entropy(mergeNode) - (Entropy(tb) + Entropy(fb)) / 2;
        Debug.WriteLine(delta);
        if (delta < mingain)
        {
            // Gain too small: collapse this node into a single leaf.
            tree.Tb = null;
            tree.Fb = null;
            tree.Results = UniqueCounts(mergeNode);
        }
    }
}


// Demo: prune with a small threshold (0.1) first, then with a large
// threshold (1.01) to force merging, printing the tree after each pass.
var treePredict = new TreePredict();
var tree= treePredict.BuildTree(TreePredict.MyData);
treePredict.Prune(tree,0.1f);
treePredict.PrintTree(tree);
Console.WriteLine("--------------------------");
treePredict.Prune(tree, 1.01f);
treePredict.PrintTree(tree);


### 处理缺失数据

/// <summary>
/// Classify an observation that may contain missing (null) values.
/// When the value tested at a node is missing, both branches are followed
/// and their results merged, each weighted by its share of result counts.
/// </summary>
public Dictionary<string, float> MdClassify(object[] observation, DecisionNode tree)
{
    // Leaf: return its counts as floats so weighted results merge uniformly.
    if (tree.Results != null)
        return tree.Results.ToDictionary(r=>r.Key,r=>(float)r.Value);
    var v = observation[tree.Col];
    if (v == null)
    {
        // Missing value: evaluate both branches and merge the weighted results.
        var tr = MdClassify(observation, tree.Tb);
        var fr = MdClassify(observation, tree.Fb);
        // Weight each branch by its total result count.
        // (Original used Values.Count — the number of DISTINCT outcomes —
        // instead of the sum of the counts.)
        var tcount = tr.Values.Sum();
        var fcount = fr.Values.Sum();
        var tw = tcount / (tcount + fcount);
        var fw = fcount / (tcount + fcount);
        var result = tr.ToDictionary(trKvp => trKvp.Key, trKvp => trKvp.Value * tw);
        foreach (var frKvp in fr)
        {
            // Original guarded "+=" with !ContainsKey, which throws
            // KeyNotFoundException on new keys and drops the false-branch
            // contribution for shared keys; sum when present, insert when not.
            if (result.ContainsKey(frKvp.Key))
                result[frKvp.Key] += frKvp.Value * fw;
            else
                result[frKvp.Key] = frKvp.Value * fw;
        }
        return result;
    }
    else
    {
        // Value present: descend exactly as Classify does.
        DecisionNode branch;
        if (v is int || v is float)
        {
            var val = v is int ? Convert.ToInt32(v) : Convert.ToSingle(v);
            var treeVal = tree.Value is int ? Convert.ToInt32(tree.Value) : Convert.ToSingle(tree.Value);
            branch = val >= treeVal ? tree.Tb : tree.Fb;
        }
        else
        {
            branch = v.ToString() == tree.Value.ToString() ? tree.Tb : tree.Fb;
        }
        return MdClassify(observation, branch);
    }
}


// Demo: classify observations with missing (null) fields via MdClassify.
var treePredict = new TreePredict();
var tree = treePredict.BuildTree(TreePredict.MyData);
var result = treePredict.MdClassify(new object[] { "google", null, "yes", null }, tree);
Console.WriteLine(JsonConvert.SerializeObject(result));
result = treePredict.MdClassify(new object[] { "google", "France", null, null }, tree);
Console.WriteLine(JsonConvert.SerializeObject(result));


### 数值型结果

/// <summary>
/// Population variance of the numeric result column (the last column),
/// usable as the impurity score when the outcome is a number.
/// </summary>
public float Variance(List<object[]> rows)
{
    // An empty set has no spread.
    if (rows.Count == 0)
        return 0f;

    // Extract the numeric results from the last column.
    var values = rows.Select(row => Convert.ToSingle(row.Last())).ToList();
    var mean = values.Average();

    // Mean squared deviation from the mean.
    return values.Select(v => (float)Math.Pow(v - mean, 2)).Average();
}


### 总结

posted @ 2017-02-19 17:55  hystar  阅读(412)  评论(0编辑  收藏  举报