西瓜书7.3 朴素贝叶斯分类器

实现拉普拉斯修正的朴素贝叶斯分类器

  1. 定义输入

    输入为Object二维数组objects,objects[i][j]表示第i个西瓜的第j个属性

     // Training set: 17 watermelon samples, one row per sample.
     // Columns 0-5 are categorical attributes (color, root, knock sound, texture, navel, touch),
     // columns 6-7 are numeric attributes (density, sugar content),
     // and the last column is the class label ("好瓜" = good melon, "坏瓜" = bad melon).
     Object[][]objects={ {"青绿", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.697, 0.460, "好瓜"},
                    {"乌黑", "蜷缩", "沉闷", "清晰", "凹陷", "硬滑", 0.774, 0.376, "好瓜"},
                    {"乌黑", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.634, 0.264, "好瓜"},
                    {"青绿", "蜷缩", "沉闷", "清晰", "凹陷", "硬滑", 0.608, 0.318, "好瓜"},
                    {"浅白", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.556, 0.215, "好瓜"},
                    {"青绿", "稍蜷", "浊响", "清晰", "稍凹", "软粘", 0.403, 0.237, "好瓜"},
                    {"乌黑", "稍蜷", "浊响", "稍糊", "稍凹", "软粘", 0.481, 0.149, "好瓜"},
                    {"乌黑", "稍蜷", "浊响", "清晰", "稍凹", "硬滑", 0.437, 0.211, "好瓜"},
                    {"乌黑", "稍蜷", "沉闷", "稍糊", "稍凹", "硬滑", 0.666, 0.091, "坏瓜"},
                    {"青绿", "硬挺", "清脆", "清晰", "平坦", "软粘", 0.243, 0.267, "坏瓜"},
                    {"浅白", "硬挺", "清脆", "模糊", "平坦", "硬滑", 0.245, 0.057, "坏瓜"},
                    {"浅白", "蜷缩", "浊响", "模糊", "平坦", "软粘", 0.343, 0.099, "坏瓜"},
                    {"青绿", "稍蜷", "浊响", "稍糊", "凹陷", "硬滑", 0.639, 0.161, "坏瓜"},
                    {"浅白", "稍蜷", "沉闷", "稍糊", "凹陷", "硬滑", 0.657, 0.198, "坏瓜"},
                    {"乌黑", "稍蜷", "浊响", "清晰", "稍凹", "软粘", 0.360, 0.370, "坏瓜"},
                    {"浅白", "蜷缩", "浊响", "模糊", "平坦", "硬滑", 0.593, 0.042, "坏瓜"},
                    {"青绿", "蜷缩", "沉闷", "稍糊", "稍凹", "硬滑", 0.719, 0.103, "坏瓜"}
            };
    
    

    为了方便,将每一行objects[i]封装为一个Entity对象,属性名与属性值以键值对的形式放入一个map中

    // Wraps one sample row: each attribute name is mapped to that sample's value for it.
    public class Entity {
        // Attribute name -> attribute value for this sample.
        private Map<String,Object> map=new HashMap<String, Object>();
        // Name of the class-label attribute (the last element of labels).
        private String classifyLabel;
    
        // Pairs labels[i] with objects[i] for every index of objects,
        // then records the last label as the class-label attribute name.
        public Entity(Object[] objects, String[] labels) {
            for(int i=0;i<objects.length;i++){
                map.put(labels[i],objects[i]);
            }
            classifyLabel=labels[labels.length-1];
        }
    
  2. 代码

    Main

    package com.fly;
    
    /**
     * Demo entry point: trains the naive Bayes classifier on the watermelon data set
     * and prints the predicted class of one query sample.
     */
    public class Main {
        public static void main(String[] args) {
            String[] attributeNames=attributeNames();
            Object[][] trainingRows=trainingRows();
            Object[] query=querySample();

            BayesClassifier classifier=new BayesClassifier(trainingRows,attributeNames);
            Entity queryEntity=new Entity(query,attributeNames);
            Object predicted=classifier.getClassify(queryEntity);
            System.out.println("分类结果为"+predicted);
        }

        // Attribute names, in column order; the last one names the class-label column.
        private static String[] attributeNames() {
            return new String[]{"色泽","根蒂","敲声","纹理","脐部","触感","密度","含糖率","好瓜"};
        }

        // The 17 labelled training samples, one row per watermelon.
        private static Object[][] trainingRows() {
            return new Object[][]{
                    {"青绿", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.697, 0.460, "好瓜"},
                    {"乌黑", "蜷缩", "沉闷", "清晰", "凹陷", "硬滑", 0.774, 0.376, "好瓜"},
                    {"乌黑", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.634, 0.264, "好瓜"},
                    {"青绿", "蜷缩", "沉闷", "清晰", "凹陷", "硬滑", 0.608, 0.318, "好瓜"},
                    {"浅白", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.556, 0.215, "好瓜"},
                    {"青绿", "稍蜷", "浊响", "清晰", "稍凹", "软粘", 0.403, 0.237, "好瓜"},
                    {"乌黑", "稍蜷", "浊响", "稍糊", "稍凹", "软粘", 0.481, 0.149, "好瓜"},
                    {"乌黑", "稍蜷", "浊响", "清晰", "稍凹", "硬滑", 0.437, 0.211, "好瓜"},
                    {"乌黑", "稍蜷", "沉闷", "稍糊", "稍凹", "硬滑", 0.666, 0.091, "坏瓜"},
                    {"青绿", "硬挺", "清脆", "清晰", "平坦", "软粘", 0.243, 0.267, "坏瓜"},
                    {"浅白", "硬挺", "清脆", "模糊", "平坦", "硬滑", 0.245, 0.057, "坏瓜"},
                    {"浅白", "蜷缩", "浊响", "模糊", "平坦", "软粘", 0.343, 0.099, "坏瓜"},
                    {"青绿", "稍蜷", "浊响", "稍糊", "凹陷", "硬滑", 0.639, 0.161, "坏瓜"},
                    {"浅白", "稍蜷", "沉闷", "稍糊", "凹陷", "硬滑", 0.657, 0.198, "坏瓜"},
                    {"乌黑", "稍蜷", "浊响", "清晰", "稍凹", "软粘", 0.360, 0.370, "坏瓜"},
                    {"浅白", "蜷缩", "浊响", "模糊", "平坦", "硬滑", 0.593, 0.042, "坏瓜"},
                    {"青绿", "蜷缩", "沉闷", "稍糊", "稍凹", "硬滑", 0.719, 0.103, "坏瓜"}
            };
        }

        // The sample to classify; its class-label slot holds the placeholder "?".
        private static Object[] querySample() {
            return new Object[]{"青绿", "蜷缩", "浊响", "清晰", "凹陷", "硬滑", 0.697, 0.460, "?"};
        }

    }
    
    

    Entity

    package com.fly;
    
    import java.util.HashMap;
    import java.util.Hashtable;
    import java.util.Map;
    import java.util.Set;
    
    /**
     * One sample of the data set, stored as a map from attribute name to attribute value.
     */
    public class Entity {
        // Attribute name -> attribute value for this sample.
        private final Map<String,Object> map=new HashMap<String, Object>();
        // Name of the class-label attribute (the last element of labels).
        private final String classifyLabel;

        /**
         * Builds a sample by pairing {@code labels[i]} with {@code objects[i]}.
         *
         * @param objects attribute values of the sample; may be shorter than {@code labels},
         *                in which case the trailing labels are simply not set
         * @param labels  attribute names, the last one being the class-label attribute
         * @throws NullPointerException     if either array is null
         * @throws IllegalArgumentException if {@code labels} is empty or {@code objects}
         *                                  has more elements than {@code labels}
         */
        public Entity(Object[] objects, String[] labels) {
            if(objects==null){
                throw new NullPointerException("objects");
            }
            if(labels==null){
                throw new NullPointerException("labels");
            }
            if(labels.length==0){
                throw new IllegalArgumentException("labels must not be empty");
            }
            // Fail before touching the map instead of overrunning labels half-way through.
            if(objects.length>labels.length){
                throw new IllegalArgumentException("objects has "+objects.length
                        +" values but only "+labels.length+" labels were given");
            }
            for(int i=0;i<objects.length;i++){
                map.put(labels[i],objects[i]);
            }
            classifyLabel=labels[labels.length-1];
        }

        /**
         * Returns the names of all attributes stored in this sample.
         *
         * @return a read-only view of the attribute names
         */
        public Set<String>getProperty(){
            // Read-only view so callers cannot delete attributes through the key set.
            return java.util.Collections.unmodifiableSet(map.keySet());
        }

        /**
         * Returns the value of one attribute.
         *
         * @param property attribute name
         * @return the stored value, or null when the attribute is not present
         */
        public Object getValue(String property){
            return map.get(property);
        }

        @Override
        public String toString() {
            return "Entity{" +
                    "map=" + map +
                    '}';
        }


    }
    
    

    分类器

    package com.fly;
    
    import java.util.ArrayList;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Set;
    
    /**
     * Naive Bayes classifier. Discrete attributes use Laplace-corrected frequency
     * estimates; attributes whose query value is a {@link Number} use a Gaussian density
     * fitted to the training samples of the class.
     */
    public class BayesClassifier {
        //Training set, one Entity per row of the input array
        List<Entity>entities;
        //Name of the class-label attribute (last element of labels)
        String classifyLabel;
        //Distinct class values, taken from the last element of every training row
        Set<Object>classes=new HashSet<Object>();

        /**
         * Builds the classifier from raw training rows.
         *
         * @param a      training rows; the last element of each row is its class value
         * @param labels attribute names, the last one naming the class-label attribute
         */
        public BayesClassifier(Object[][] a,String[] labels){
            entities=new ArrayList<Entity>();
            for(Object[] objects:a){
                Entity entity=new Entity(objects,labels);
                entities.add(entity);
                classes.add(objects[objects.length-1]);
            }
            classifyLabel=labels[labels.length-1];

        }

        /**
         * Picks the class with the largest score {@link #p(Entity, Object)} and prints
         * the score of every class to standard output.
         *
         * @param entity sample to classify
         * @return the winning class value, or null when no class has a score greater
         *         than zero (for example when every score is NaN)
         */
        public Object getClassify(Entity entity){
            Object classify=null;
            double best=0;
            for(Object c:classes){
                // Score each class once; every evaluation scans the whole training set.
                double score=p(entity,c);
                System.out.println("分类为"+c+"的概率是"+score);
                if(score>best){
                    best=score;
                    classify=c;
                }
            }
            return classify;
        }

        /**
         * Computes the unnormalised posterior of class {@code c} for a sample: the prior
         * of {@code c} multiplied by the conditional probability of every attribute of
         * the sample except the class-label attribute.
         *
         * @param entity sample to score
         * @param c      candidate class value
         * @return prior times the product of the per-attribute conditional probabilities
         */
        public double p(Entity entity,Object c){
            Set<String> properties = entity.getProperty();
            double ans=priorP(c);
            for(String property:properties){
                if(!property.equals(classifyLabel)){
                    ans*=p(property,entity.getValue(property),classifyLabel,c);
                }
            }
            return ans;
        }

        /**
         * Laplace-corrected prior of class {@code c}:
         * (count of training samples in c + 1) / (training set size + number of classes).
         *
         * @param c class value
         * @return the corrected prior probability
         */
        public double priorP(Object c){
            int num=0;
            for(Entity entity:entities){
                if (sameValue(entity.getValue(classifyLabel),c)){
                    num++;
                }
            }
            return (num+1)*1.0/(entities.size()+classes.size());
        }

        /**
         * Conditional probability (or density) of attribute {@code label1} taking value
         * {@code x} given that attribute {@code label2} equals {@code c}.
         *
         * <p>When {@code x} is a {@link Number} the attribute is treated as continuous and
         * a Gaussian density is returned; otherwise a Laplace-corrected frequency is
         * returned.
         *
         * @param label1 name of the attribute being evaluated
         * @param x      value of that attribute in the query sample
         * @param label2 name of the conditioning attribute (the class label)
         * @param c      value of the conditioning attribute
         * @return the estimated conditional probability or density
         */
        public double p(String label1,Object x,String label2,Object c){
            if(x instanceof Number){
                return gaussianDensity(label1,((Number) x).doubleValue(),label2,c);
            }
            return laplaceFrequency(label1,x,label2,c);
        }

        //Laplace-corrected frequency: (matches in class + 1) / (class size + distinct values of label1)
        private double laplaceFrequency(String label1,Object x,String label2,Object c){
            double matching=0;
            double inClass=0;
            for(Entity entity:entities){
                if(sameValue(entity.getValue(label2),c)){
                    inClass++;
                    if(sameValue(entity.getValue(label1),x)){
                        matching++;
                    }
                }
            }
            return (matching+1)/(inClass+getN(label1));
        }

        //Gaussian density of x under the mean and sample standard deviation of label1 within class c.
        //With fewer than two samples in the class, or identical samples, the deviation is NaN or 0
        //and the result is NaN, so getClassify never selects that class.
        private double gaussianDensity(String label1,double x,String label2,Object c){
            List<Number>list=new ArrayList<Number>();
            for(Entity entity:entities){
                if(sameValue(entity.getValue(label2),c)){
                    list.add((Number) entity.getValue(label1));
                }
            }
            double average = getAverage(list);
            double deviation = getStandardDeviation(list);
            double t=-(x-average)*(x-average)/2/deviation/deviation;
            return 1/Math.sqrt(2*Math.PI)/deviation*Math.exp(t);
        }

        //Null-safe equality, so a training row that lacks an attribute does not throw
        private static boolean sameValue(Object a,Object b){
            return a==null ? b==null : a.equals(b);
        }

        /**
         * Counts the distinct values that attribute {@code label1} takes in the training set.
         *
         * @param label1 attribute name
         * @return number of distinct values (as a double)
         */
        public double getN(String label1) {
            Set<Object> set = new HashSet<Object>();
            for(Entity entity:entities){
                set.add(entity.getValue(label1));
            }
            return set.size();
        }

        /**
         * Sample standard deviation: the square root of {@link #getVariance(List)}.
         *
         * @param list values
         * @return the standard deviation
         */
        public double getStandardDeviation(List<Number>list){
            return Math.sqrt(getVariance(list));
        }

        /**
         * Sample variance: sum of squared deviations from the mean divided by (size - 1).
         *
         * @param list values
         * @return the variance; NaN for a single-element list
         */
        public double getVariance(List<Number>list){
            double ans=0;
            double average = getAverage(list);
            for(Number number:list){
                double t =((number.doubleValue()-average)*(number.doubleValue()-average));
                ans+=t;
            }
            return ans/(list.size()-1);
        }

        /**
         * Arithmetic mean.
         *
         * @param list values
         * @return the mean; NaN for an empty list
         */
        public double getAverage(List<Number>list){
            double ans=0;
            for(Number number:list){
                ans+=number.doubleValue();
            }
            return ans/list.size();
        }
    }
    
    
  3. 一些问题
    这里通过 instanceof Number 来判断属性是离散还是连续;若某个标称(离散)属性的取值是用数值表示的,它会被误判为连续属性并按高斯分布计算,从而出现问题

posted on 2020-12-17 11:10  计网好难  阅读(178)  评论(0)    收藏  举报