西瓜书7.3 朴素贝叶斯分类器
实现拉普拉斯修正的朴素贝叶斯分类器
-
定义输入
输入为Object二维数组objects,objects[i][j]表示第i个西瓜的第j个属性
Object[][]objects={ {"青绿", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.697, 0.460, "好瓜"}, {"乌黑", "蜷缩", "沉闷", "清晰", "凹陷", "硬滑", 0.774, 0.376, "好瓜"}, {"乌黑", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.634, 0.264, "好瓜"}, {"青绿", "蜷缩", "沉闷", "清晰", "凹陷", "硬滑", 0.608, 0.318, "好瓜"}, {"浅白", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.556, 0.215, "好瓜"}, {"青绿", "稍蜷", "浊响", "清晰", "稍凹", "软粘", 0.403, 0.237, "好瓜"}, {"乌黑", "稍蜷", "浊响", "稍糊", "稍凹", "软粘", 0.481, 0.149, "好瓜"}, {"乌黑", "稍蜷", "浊响", "清晰", "稍凹", "硬滑", 0.437, 0.211, "好瓜"}, {"乌黑", "稍蜷", "沉闷", "稍糊", "稍凹", "硬滑", 0.666, 0.091, "坏瓜"}, {"青绿", "硬挺", "清脆", "清晰", "平坦", "软粘", 0.243, 0.267, "坏瓜"}, {"浅白", "硬挺", "清脆", "模糊", "平坦", "硬滑", 0.245, 0.057, "坏瓜"}, {"浅白", "蜷缩", "浊响", "模糊", "平坦", "软粘", 0.343, 0.099, "坏瓜"}, {"青绿", "稍蜷", "浊响", "稍糊", "凹陷", "硬滑", 0.639, 0.161, "坏瓜"}, {"浅白", "稍蜷", "沉闷", "稍糊", "凹陷", "硬滑", 0.657, 0.198, "坏瓜"}, {"乌黑", "稍蜷", "浊响", "清晰", "稍凹", "软粘", 0.360, 0.370, "坏瓜"}, {"浅白", "蜷缩", "浊响", "模糊", "平坦", "硬滑", 0.593, 0.042, "坏瓜"}, {"青绿", "蜷缩", "沉闷", "稍糊", "稍凹", "硬滑", 0.719, 0.103, "坏瓜"} };
为了方便,将每一行 objects[i] 封装为一个 Entity 对象,属性名与属性值存放在一个 Map 中:
public class Entity { private Map<String,Object> map=new HashMap<String, Object>(); //种类名 private String classifyLabel; public Entity(Object[] objects, String[] labels) { for(int i=0;i<objects.length;i++){ map.put(labels[i],objects[i]); } classifyLabel=labels[labels.length-1]; } -
代码
Main
package com.fly; public class Main { public static void main(String[] args) { String[]labels={"色泽","根蒂","敲声","纹理","脐部","触感","密度","含糖率","好瓜"}; Object[][]objects={ {"青绿", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.697, 0.460, "好瓜"}, {"乌黑", "蜷缩", "沉闷", "清晰", "凹陷", "硬滑", 0.774, 0.376, "好瓜"}, {"乌黑", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.634, 0.264, "好瓜"}, {"青绿", "蜷缩", "沉闷", "清晰", "凹陷", "硬滑", 0.608, 0.318, "好瓜"}, {"浅白", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.556, 0.215, "好瓜"}, {"青绿", "稍蜷", "浊响", "清晰", "稍凹", "软粘", 0.403, 0.237, "好瓜"}, {"乌黑", "稍蜷", "浊响", "稍糊", "稍凹", "软粘", 0.481, 0.149, "好瓜"}, {"乌黑", "稍蜷", "浊响", "清晰", "稍凹", "硬滑", 0.437, 0.211, "好瓜"}, {"乌黑", "稍蜷", "沉闷", "稍糊", "稍凹", "硬滑", 0.666, 0.091, "坏瓜"}, {"青绿", "硬挺", "清脆", "清晰", "平坦", "软粘", 0.243, 0.267, "坏瓜"}, {"浅白", "硬挺", "清脆", "模糊", "平坦", "硬滑", 0.245, 0.057, "坏瓜"}, {"浅白", "蜷缩", "浊响", "模糊", "平坦", "软粘", 0.343, 0.099, "坏瓜"}, {"青绿", "稍蜷", "浊响", "稍糊", "凹陷", "硬滑", 0.639, 0.161, "坏瓜"}, {"浅白", "稍蜷", "沉闷", "稍糊", "凹陷", "硬滑", 0.657, 0.198, "坏瓜"}, {"乌黑", "稍蜷", "浊响", "清晰", "稍凹", "软粘", 0.360, 0.370, "坏瓜"}, {"浅白", "蜷缩", "浊响", "模糊", "平坦", "硬滑", 0.593, 0.042, "坏瓜"}, {"青绿", "蜷缩", "沉闷", "稍糊", "稍凹", "硬滑", 0.719, 0.103, "坏瓜"} }; Object[]object= {"青绿", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.697, 0.460, "?"}; BayesClassifier classifier=new BayesClassifier(objects,labels); Entity entity = new Entity(object, labels); System.out.println("分类结果为"+classifier.getClassify(entity)); } }Entity
package com.fly; import java.util.HashMap; import java.util.Hashtable; import java.util.Map; import java.util.Set; public class Entity { private Map<String,Object> map=new HashMap<String, Object>(); private String classifyLabel; public Entity(Object[] objects, String[] labels) { for(int i=0;i<objects.length;i++){ map.put(labels[i],objects[i]); } classifyLabel=labels[labels.length-1]; } //获取所有属性 public Set<String>getProperty(){ return map.keySet(); } //获取某个属性的值 public Object getValue(String property){ return map.get(property); } @Override public String toString() { return "Entity{" + "map=" + map + '}'; } }分类器
package com.fly; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; public class BayesClassifier { //训练集 List<Entity>entities; //类名 String classifyLabel; //分类结果的集合 Set<Object>classes=new HashSet<Object>(); public BayesClassifier(Object[][] a,String[] labels){ entities=new ArrayList<Entity>(); for(Object[] objects:a){ Entity entity=new Entity(objects,labels); entities.add(entity); classes.add(objects[objects.length-1]); } classifyLabel=labels[labels.length-1]; } //获取分类 public Object getClassify(Entity entity){ Set<String> property = entity.getProperty(); Object classify=null; double p=0; for(Object c:classes){ System.out.println("分类为"+c+"的概率是"+p(entity, c)); if(p(entity,c)>p){ p=p(entity,c); classify=c; } } return classify; } //分类为c的概率 public double p(Entity entity,Object c){ Set<String> properties = entity.getProperty(); double ans=priorP(c); for(String property:properties){ if(!property.equals(classifyLabel)){ ans*=p(property,entity.getValue(property),classifyLabel,c); } } return ans; } //先验概率 public double priorP(Object c){ int num=0; for(Entity entity:entities){ if (entity.getValue(classifyLabel).equals(c)){ num++; } } //return num*1.0/entities.size(); return (num+1)*1.0/(entities.size()+classes.size()); } //条件概率p(x|c) public double p(String label1,Object x,String label2,Object c){ double ans=0; if(! 
(x instanceof Number)){ double num1=0; double num2=0; for(Entity entity:entities){ Object value1 = entity.getValue(label1); Object value2 = entity.getValue(label2); //System.out.println(value1+" "+value2); if(value2.equals(c)){ num2++; if(value1.equals(x)){ num1++; } } } return (num1+1)/(num2+getN(label1)); }else{ List<Number>list=new ArrayList<Number>(); for(Entity entity:entities){ Object value1 = entity.getValue(label1); Object value2 = entity.getValue(label2); if(value2.equals(c)){ list.add((Number) value1); } } double average = getAverage(list); double deviation = getStandardDeviation(list); //System.out.println(average+" "+variance); Double number = ((Number) x).doubleValue(); double t=-(number-average)*(number-average)/2/deviation/deviation; ans=1/Math.sqrt(2*Math.PI)/deviation*Math.exp(t); return ans; } } //获取label属性值的种类数量 public double getN(String label1) { Set<Object> set = new HashSet<Object>(); for(Entity entity:entities){ set.add(entity.getValue(label1)); } return set.size(); } //获取标准差 public double getStandardDeviation(List<Number>list){ return Math.sqrt(getVariance(list)); } //获取方差 public double getVariance(List<Number>list){ double ans=0; double average = getAverage(list); for(Number number:list){ double t =((number.doubleValue()-average)*(number.doubleValue()-average)); ans+=t; } return ans/(list.size()-1); } //获取平均数 public double getAverage(List<Number>list){ double ans=0; for(Number number:list){ ans+=number.doubleValue(); } return ans/list.size(); } } -
一些问题
这里用 instanceof Number 来判断属性是连续型还是离散型;如果某个标称(离散)属性恰好用数值编码,就会被误当作连续属性、按高斯密度处理,结果会出现问题。
浙公网安备 33010602011771号