注意 addAll 的用法：借助它，可以把递归过程中收集到的信息全部合并成一个 List 返回。
构造决策树时，node 既作为参数传入，又作为返回值返回（这个思路是在路上偶然想到的）。
package TreeStructure;
import java.util.ArrayList;
import java.util.List;
/**
 * Demo driver: builds a decision tree from a small hard-coded numeric data set with
 * {@link DecisionTree#createDT} and classifies one query with {@link TreeNode#traverse(List)}.
 */
public class testClass {
    public static void main(String[] args) {
        double[][] exercise = {{1,1,0,0},{1,3,1,1},{3,2,0,0},{3,2,1,10},{3,2,1,10},{3,2,1,10},{2,2,1,1},{3,2,1,9},{2,3,0,1},{2,1,0,0},{3,2,0,1},{2,1,0,1},{1,1,0,1}};
        // Names of the columns AFTER reordering; createDT splits on them in this order and
        // treats the last one as the target.
        String[] attributeNames = {"weather","thin","cloth","target"};
        // Column j of exerciseData is taken from column columnOrder[j] of exercise.
        int[] columnOrder = {1,0,2,3};
        double[][] exerciseData = reorderColumns(exercise, columnOrder);
        printMatrix(exerciseData);

        DecisionTree dt = new DecisionTree();
        List<ArrayList<String>> data = toStringRows(exerciseData);
        List<String> attribute = new ArrayList<String>();
        for (String name : attributeNames) {
            attribute.add(name);
        }
        // Passing null lets createDT create the "start" root node itself.
        TreeNode node = dt.createDT(data, attribute, null);

        // Query values for the leading attributes, in column order.
        double[] dataExercise = {2,3};
        List<Double> query = new ArrayList<Double>();
        for (double value : dataExercise) {
            query.add(value);
        }
        node.traverse(query);
        System.out.println();
    }

    /** Returns a copy of {@code source} whose column j is column {@code order[j]} of the source row. */
    private static double[][] reorderColumns(double[][] source, int[] order) {
        double[][] result = new double[source.length][];
        for (int row = 0; row < source.length; row++) {
            result[row] = new double[source[row].length];
            for (int col = 0; col < result[row].length; col++) {
                result[row][col] = source[row][order[col]];
            }
        }
        return result;
    }

    /** Prints each row on its own line, every value preceded by a single space. */
    private static void printMatrix(double[][] matrix) {
        for (double[] row : matrix) {
            for (double value : row) {
                System.out.print(" " + value);
            }
            System.out.println();
        }
    }

    /** Converts every numeric cell to its {@code Double.toString} form (e.g. 1 -> "1.0"). */
    private static List<ArrayList<String>> toStringRows(double[][] matrix) {
        List<ArrayList<String>> rows = new ArrayList<ArrayList<String>>();
        for (double[] row : matrix) {
            ArrayList<String> cells = new ArrayList<String>();
            for (double value : row) {
                cells.add(String.valueOf(value));
            }
            rows.add(cells);
        }
        return rows;
    }
}
package TreeStructure;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Builds a decision tree by splitting on the attributes in the order they are listed.
 *
 * <p>NOTE(review): the log line claims the max gain-ratio attribute was selected, but the code
 * always splits on {@code attributeList.get(0)}; no information gain is computed here.
 */
public class DecisionTree {
    /**
     * Recursively attaches subtrees for {@code data} under {@code node}.
     *
     * <p>Each row of {@code data} holds one value per entry of {@code attributeList}, in the same
     * order. While more than one attribute remains, one child per distinct value of the first
     * attribute is added, and recursion continues on the matching rows with that column removed.
     * When only one attribute (the target) remains, one "target" leaf per row is added.
     * Duplicates are kept so that callers can take a majority vote.
     *
     * @param data          rows to split; rows passed to the recursion are copies made by
     *                      {@link InfoGain}, so the caller's rows are not modified here
     * @param attributeList attribute names; the last one is treated as the target
     * @param node          node to attach children to, or {@code null} to create a "start" root
     * @return {@code node}, or the newly created root when {@code node} was {@code null}
     */
    public TreeNode createDT(List<ArrayList<String>> data,List<String> attributeList,TreeNode node){
        printState(data, attributeList);
        if (node == null) {
            node = new TreeNode();
            node.setAttributeValue("start");
            node.setNodeName("start");
        }
        if (attributeList.isEmpty()) {
            // No column left to split on or to read a target from; previously this crashed
            // with IndexOutOfBoundsException on attributeList.get(0).
            return node;
        }
        if (attributeList.size() == 1) {
            // Only the target column remains: every row becomes one leaf.
            for (ArrayList<String> row : data) {
                TreeNode leafNode = new TreeNode();
                leafNode.setAttributeValue(row.get(0));
                leafNode.setNodeName("target");
                node.getChildTreeNode().add(leafNode);
            }
            return node;
        }

        String splitAttribute = attributeList.get(0);
        System.out.println("选择出的最大增益率属性为: " + splitAttribute);
        InfoGain gain = new InfoGain(data, attributeList);
        Map<String, Long> valueCounts = gain.getAttributeValue(0);
        // Neither this method nor InfoGain mutates the attribute list, so one copy can be
        // shared by all branches.
        List<String> remainingAttributes =
                new ArrayList<String>(attributeList.subList(1, attributeList.size()));
        for (String value : valueCounts.keySet()) {
            List<ArrayList<String>> subset = gain.getData4Value(value, 0);
            TreeNode branch = new TreeNode();
            branch.setAttributeValue(value);
            branch.setNodeName(splitAttribute);
            node.getChildTreeNode().add(branch);
            System.out.println("当前为" + splitAttribute + "的" + value + "分支。");
            // Drop the column just split on; subset rows are copies, so this is safe.
            for (ArrayList<String> row : subset) {
                row.remove(0);
            }
            createDT(subset, remainingAttributes, branch);
        }
        return node;
    }

    /** Debug dump of the rows and attribute names seen by one createDT call. */
    private static void printState(List<ArrayList<String>> data, List<String> attributeList) {
        System.out.println("当前的DATA为");
        for (ArrayList<String> row : data) {
            for (String cell : row) {
                System.out.print(cell + " ");
            }
            System.out.println();
        }
        System.out.println("---------------------------------");
        System.out.println("当前的ATTR为");
        for (String attribute : attributeList) {
            System.out.print(attribute + " ");
        }
        System.out.println();
        System.out.println("---------------------------------");
    }
}
package TreeStructure;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Holds a private copy of a row-oriented data set and answers grouping and filtering queries
 * on it.
 *
 * <p>Each row holds one String value per attribute, in attribute order.
 *
 * <p>NOTE(review): despite the class name, no entropy or information gain is computed here.
 */
public class InfoGain {
    private List<ArrayList<String>> data;
    private List<String> attribute;

    /**
     * Copies {@code data} row by row and copies {@code attribute}, so that later changes by the
     * caller do not affect this instance.
     */
    public InfoGain(List<ArrayList<String>> data,List<String> attribute){
        this.data = new ArrayList<ArrayList<String>>(data.size());
        for (ArrayList<String> row : data) {
            this.data.add(new ArrayList<String>(row));
        }
        this.attribute = new ArrayList<String>(attribute);
    }

    /**
     * Counts how many rows carry each distinct value in column {@code attributeIndex}.
     * Values are compared exactly (case-sensitive); a null cell is counted under a null key.
     */
    public Map<String,Long> getAttributeValue(int attributeIndex){
        Map<String,Long> counts = new HashMap<String,Long>();
        for (ArrayList<String> row : data) {
            String key = row.get(attributeIndex);
            Long seen = counts.get(key);
            counts.put(key, seen == null ? 1L : seen + 1L);
        }
        return counts;
    }

    /**
     * Returns fresh copies of all rows whose column {@code attrIndex} equals {@code attrValue}.
     *
     * <p>Matching uses the same exact equality as {@link #getAttributeValue(int)}. Earlier,
     * {@code equalsIgnoreCase} let rows such as "A" and "a" fall into both branches. The
     * returned rows may be mutated freely; the stored data is unaffected.
     */
    public List<ArrayList<String>> getData4Value(String attrValue,int attrIndex){
        List<ArrayList<String>> matches = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> row : data) {
            String cell = row.get(attrIndex);
            boolean same = (cell == null) ? attrValue == null : cell.equals(attrValue);
            if (same) {
                matches.add(new ArrayList<String>(row));
            }
        }
        return matches;
    }

    /**
     * Collects the last value of every row (the target column).
     * Empty rows have no target and are skipped; earlier, the first empty row ended the scan.
     */
    public static List<String> getTarget(List<ArrayList<String>> data){
        List<String> targets = new ArrayList<String>();
        for (ArrayList<String> row : data) {
            if (row.isEmpty()) {
                continue;
            }
            targets.add(row.get(row.size() - 1));
        }
        return targets;
    }

    /**
     * Checks whether the labels are 100% pure.
     *
     * @return the single label when every element of {@code list} equals the first one, or
     *         {@code null} when the labels are mixed or {@code list} is null/empty
     */
    public static String IsPure(List<String> list){
        if (list == null || list.isEmpty()) {
            return null;
        }
        String first = list.get(0);
        for (String label : list) {
            boolean same = (first == null) ? label == null : first.equals(label);
            if (!same) {
                return null;
            }
        }
        return first;
    }
}
package TreeStructure;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
class TreeNode{
private String attributeValue;
private List<TreeNode> childTreeNode;
private List<String> pathName;
private String targetFunValue;
private String nodeName;
public TreeNode(String nodeName){
this.nodeName = nodeName;
this.childTreeNode = new ArrayList<TreeNode>();
this.pathName = new ArrayList<String>();
}
public TreeNode(){
this.childTreeNode = new ArrayList<TreeNode>();
this.pathName = new ArrayList<String>();
}
public String getAttributeValue() {
return attributeValue;
}
public void setAttributeValue(String attributeValue) {
this.attributeValue = attributeValue;
}
public List<TreeNode> getChildTreeNode() {
return childTreeNode;
}
public void setChildTreeNode(List<TreeNode> childTreeNode) {
this.childTreeNode = childTreeNode;
}
public String getTargetFunValue() {
return targetFunValue;
}
public void setTargetFunValue(String targetFunValue) {
this.targetFunValue = targetFunValue;
}
public String getNodeName() {
return nodeName;
}
public void setNodeName(String nodeName) {
this.nodeName = nodeName;
}
public List<String> getPathName() {
return pathName;
}
public void setPathName(List<String> pathName) {
this.pathName = pathName;
}
public void traverse() {
System.out.println(this.getNodeName()+": "+this.getAttributeValue());
int childNumber = this.childTreeNode.size();
System.out.println(childNumber);
for (int i = 0; i < childNumber; i++) {
TreeNode child = this.childTreeNode.get(i);
child.traverse();
}
}
public List getTarget(TreeNode node){
List a = new ArrayList();;
int childNum = node.getChildTreeNode().size();
if(node.childTreeNode.get(0).childTreeNode.size()==0){//表示node孩子的孩子为空,即node下一层为目标层
for(int i = 0;i<childNum;i++){
a.add(node.getChildTreeNode().get(i).getAttributeValue());
}
}else{
for(int i = 0;i<childNum;i++){
a.addAll(getTarget(node.getChildTreeNode().get(i)));
}
}
return a;
}
public void traverse(List list) {
if(list.size()==0){
List target = getTarget(this);
// int childlistNumber = this.childTreeNode.size();
// List a = new ArrayList();
// for(int i = 0;i<childlistNumber;i++){
// TreeNode child = this.childTreeNode.get(i);
// a.add(child.getAttributeValue());
// }
List b = new ArrayList();
// Map result = new HashMap();
for(int i = 0;i<target.size();i++){
if(!b.contains(target.get(i))){
b.add(target.get(i));
}
}
int []count = new int [b.size()];
for(int i = 0;i<b.size();i++){
for(int j = 0;j<target.size();j++){
if(b.get(i).equals(target.get(j))){
count[i] = count[i]+1;
}
}
System.out.println(b.get(i)+"的数量是: "+count[i]);
}
int maxIndex = 0;
for(int i = 1;i<count.length;i++){
if(count[maxIndex]<count[i]){
maxIndex = i;
}
}
System.out.println("选择"+b.get(maxIndex)+"为最终决策");
}else{
List a = new ArrayList();
double temp = (Double)list.get(0);
int childlistNumber = this.childTreeNode.size();
System.out.println(childlistNumber);
for(int i = 0;i<childlistNumber;i++){
TreeNode child = this.childTreeNode.get(i);
double tempchild = Double.valueOf(child.getAttributeValue());
if(temp==tempchild){
System.out.println(child.getNodeName()+": "+child.getAttributeValue());
list.remove(0);
child.traverse(list);
}
}
}
}
}