Learning the Logistic algorithm

  Logistic regression: a linear model whose output is passed through the logistic (sigmoid) function

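  Weka is used to train the model and save its coefficients to a file. The program then reads those coefficients back and applies the algorithm itself to analyze new data.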
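Written as a formula (the symbols are introduced here for illustration), the score that sections II and III compute for one row x is:

```latex
z = \beta_0 + \sum_{j \in \text{numeric}} \beta_j \, x_j + \sum_{k \in \text{nominal}} \gamma_{k,\,x_k},
\qquad
p = \frac{1}{1 + e^{-z}}
```

Here beta_0 is the intercept (the last coefficient in the file), beta_j is the coefficient of numeric attribute j, and gamma_{k,c} is the coefficient of category c of nominal attribute k.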

I. Read the coefficient file and return the coefficients (nominal and numeric)

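The list coeffcient receives the coefficients in file order: a Double for each numeric (or two-valued) attribute, and a Map from category code to coefficient for each nominal attribute. Each line of the file is expected to hold a name and a coefficient separated by a space; nominal attributes are written as name=code. The last entry is used as the intercept in the next step.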
// uses java.io.* and java.util.*
public static void coeffcient(List<Object> coeffcient, String file) throws IOException {
        File inFile = new File(file);
        if (!inFile.exists()) {
            System.out.println("File does not exist!");
            return;
        }
        // try-with-resources closes the reader even if parsing fails
        try (BufferedReader reader = new BufferedReader(new FileReader(inFile))) {
            String inString;
            String normalName = "";                         // name of the current nominal attribute
            Map<Double, Double> normal = new HashMap<>();   // category code -> coefficient for that attribute
            boolean flag = false;                           // true while reading a nominal attribute
            while ((inString = reader.readLine()) != null) {
                inString = inString.trim();
                int space = inString.indexOf(' ');
                if (space < 0) {
                    continue; // skip blank or malformed lines
                }
                String x = inString.substring(0, space);
                String y = inString.substring(space).trim();
                // Numeric (or two-valued) attributes have no '=' in the name: add the coefficient directly.
                // Nominal attributes are written as name=code: collect them into one map per attribute.
                int eq = x.indexOf('=');
                if (eq <= 0) {
                    if (flag) { // the previous lines were a nominal attribute: store its map first
                        coeffcient.add(Collections.synchronizedMap(normal));
                        normal = new HashMap<>();
                        flag = false;
                    }
                    coeffcient.add(Double.parseDouble(y));
                } else {
                    String thisName = x.substring(0, eq);
                    if (flag && !thisName.equals(normalName)) {
                        // a new nominal attribute follows another one: store the previous map
                        coeffcient.add(Collections.synchronizedMap(normal));
                        normal = new HashMap<>();
                    }
                    normalName = thisName;
                    normal.put(Double.parseDouble(x.substring(eq + 1)), Double.parseDouble(y));
                    flag = true;
                }
            }
            if (flag) { // the file ended with a nominal attribute: store its map too
                coeffcient.add(Collections.synchronizedMap(normal));
            }
        }
    }
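To try the file layout without creating a file, here is a hypothetical in-memory variant (CoefficientParser.parse and the sample lines are invented for illustration). It follows the same "name value" / "name=code value" convention, but adds each nominal attribute's map to the list as soon as that attribute starts, so no special case is needed when the file ends on a nominal attribute.

```java
import java.util.*;

// Hypothetical in-memory variant of coeffcient(): same line format,
// but it takes the lines directly so it can be tried without a file.
class CoefficientParser {
    static List<Object> parse(List<String> lines) {
        List<Object> out = new ArrayList<>();
        Map<Double, Double> current = null; // map of the nominal attribute being read, if any
        String currentName = null;
        for (String raw : lines) {
            String line = raw.trim();
            int space = line.indexOf(' ');
            if (space < 0) continue; // skip blank or malformed lines
            String name = line.substring(0, space);
            double value = Double.parseDouble(line.substring(space).trim());
            int eq = name.indexOf('=');
            if (eq <= 0) {             // numeric attribute or intercept
                current = null;
                currentName = null;
                out.add(value);
            } else {                   // nominal attribute written as name=code
                String attr = name.substring(0, eq);
                if (current == null || !attr.equals(currentName)) {
                    current = new HashMap<>(); // a new nominal attribute starts here
                    currentName = attr;
                    out.add(current);          // added now; later lines fill the same map
                }
                current.put(Double.parseDouble(name.substring(eq + 1)), value);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<String> lines = Arrays.asList("age 0.5", "color=1 0.2", "color=2 -0.3", "Intercept 1.5");
        System.out.println(parse(lines));
    }
}
```

Adding the map early works because the list holds a reference to it, so entries put into the map afterwards are still visible through the list.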

II. Computing the probability, step one: the linear predictor z

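 For a numeric attribute, multiply its coefficient by the value and add the product; for a nominal attribute, add the coefficient of the observed category directly. Finally, add the intercept.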

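/**
     * Step 1: sum of (coefficient * value) for numeric or two-valued attributes,
     * plus the category coefficient for each nominal attribute, plus the intercept.
     * @param coeffcient coefficients read by coeffcient(); the last element is the intercept
     * @param data one row of input values (nominal attributes given as category codes)
     * @return the linear predictor z
     */
    public static double calCoeff(List<Object> coeffcient, double[] data) {
        if (data.length != (coeffcient.size() - 1)) {
            // stop here instead of failing later with an index error
            throw new IllegalArgumentException("The coefficient file has " + (coeffcient.size() - 1)
                    + " variables but the row has " + data.length + "; the data cannot be analyzed.");
        }
        double z = 0.0;
        for (int i = 0; i < data.length; i++) {
            Object c = coeffcient.get(i);
            if (c instanceof Double) {
                // numeric attribute: coefficient * value
                z += data[i] * (Double) c;
            } else if (c instanceof Map) {
                // nominal attribute: look up the coefficient of this category
                @SuppressWarnings("unchecked")
                Map<Double, Double> map = (Map<Double, Double>) c;
                Double w = map.get(data[i]);
                if (w == null) {
                    throw new IllegalArgumentException("Unknown category " + data[i] + " in column " + i);
                }
                z += w;
            }
        }
        return z + (Double) coeffcient.get(data.length); // add the intercept (last element)
    }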
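A worked example with invented coefficients: a numeric attribute with coefficient 0.5, a nominal attribute whose categories 1 and 2 have coefficients 0.2 and -0.3, and an intercept of 1.5. For the row (2, 2), z = 0.5 * 2 + (-0.3) + 1.5 = 2.2. The sketch below restates the calCoeff rule in a standalone class (LinearPredictorDemo is a hypothetical name) so the arithmetic can be checked.

```java
import java.util.*;

class LinearPredictorDemo {
    // Same rule as calCoeff: Double -> coefficient * value, Map -> category coefficient;
    // the last list element is the intercept. Assumes every category code is in its map.
    static double z(List<Object> coeff, double[] row) {
        double z = 0.0;
        for (int i = 0; i < row.length; i++) {
            Object c = coeff.get(i);
            if (c instanceof Double) {
                z += (Double) c * row[i];
            } else {
                @SuppressWarnings("unchecked")
                Map<Double, Double> levels = (Map<Double, Double>) c;
                z += levels.get(row[i]);
            }
        }
        return z + (Double) coeff.get(row.length);
    }

    public static void main(String[] args) {
        Map<Double, Double> color = new HashMap<>();
        color.put(1.0, 0.2);
        color.put(2.0, -0.3);
        List<Object> coeff = Arrays.asList(0.5, color, 1.5); // invented coefficients
        System.out.println(z(coeff, new double[] {2.0, 2.0}));
    }
}
```

Because of floating-point rounding, the printed value may differ from 2.2 in the last digits; compare with a tolerance rather than ==.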

 III. Computing the probability, step two: the logistic function

public static double logistic(double z_middle_result) {
        // p = 1 / (1 + e^(-z)); Math.exp(x) is the idiomatic form of Math.pow(Math.E, x)
        return 1.0 / (1.0 + Math.exp(-z_middle_result));
    }
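A few properties of the logistic function help when sanity-checking results: sigma(0) = 0.5, sigma(-z) = 1 - sigma(z), and for large |z| Math.exp overflows to Infinity or underflows to 0.0, which still yields exactly 0.0 or 1.0 rather than NaN. The sketch below checks these; the label helper and its 0.5 cut-off are hypothetical additions for turning a probability into a 0/1 decision, not part of the original code.

```java
class LogisticDemo {
    // p = 1 / (1 + e^(-z)), the same formula as logistic() above
    static double logistic(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Hypothetical helper: map a probability to a 0/1 label with a chosen cut-off
    static int label(double p, double threshold) {
        return p >= threshold ? 1 : 0;
    }

    public static void main(String[] args) {
        double z = 1.3; // arbitrary example value
        double p = logistic(z);
        System.out.println("p = " + p + ", 1 - p = " + logistic(-z) + ", label = " + label(p, 0.5));
        System.out.println("extremes: " + logistic(-1000) + " and " + logistic(1000));
    }
}
```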

 

 IV. Alternatively, read the test data from a CSV file (requires the opencsv-3.8.jar library)

// uses com.opencsv.CSVReader (opencsv 3.x), java.io.FileReader, java.io.IOException, java.util.List
public static double[][] getDouble(String fileData) throws IOException {
        // open the CSV file; try-with-resources closes the reader when done
        try (CSVReader reader = new CSVReader(new FileReader(fileData))) {
            List<String[]> all = reader.readAll();
            double[][] data = new double[all.size()][];
            for (int i = 0; i < all.size(); i++) {
                String[] line = all.get(i);
                data[i] = new double[line.length]; // each row sized to its own number of cells
                for (int j = 0; j < line.length; j++) {
                    data[i][j] = Double.parseDouble(line[j]);
                }
            }
            return data;
        }
    }
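If adding the opencsv jar is not an option and the file contains only plain numbers, the standard library is enough. This is a hypothetical alternative (NumericCsv.read is an illustrative name), not a replacement for opencsv: it assumes comma separators, no quoted fields and no header row.

```java
import java.io.*;
import java.util.*;

// Hypothetical opencsv-free reader for purely numeric CSV data.
// Assumes comma separators, no quoting and no header row.
class NumericCsv {
    static double[][] read(Reader in) throws IOException {
        List<double[]> rows = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(in)) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) continue; // skip blank lines
                String[] cells = line.split(",");
                double[] row = new double[cells.length];
                for (int j = 0; j < cells.length; j++) {
                    row[j] = Double.parseDouble(cells[j]); // parseDouble ignores surrounding spaces
                }
                rows.add(row);
            }
        }
        return rows.toArray(new double[0][]);
    }
}
```

For a file on disk, NumericCsv.read(new FileReader(path)) returns the same kind of double[][] as getDouble.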

 V. Score the CSV data and write the results

public static void outResult(double[][] data, String outfile, List<Object> coeffcient) throws IOException {
        final int rowsPerFile = 10000; // start a new output .txt file every 10,000 rows
        BufferedWriter writer = null;
        try {
            for (int i = 0; i < data.length; i++) {
                if (i % rowsPerFile == 0) { // first row, or the current file is full
                    if (writer != null) {
                        writer.close();
                    }
                    int num = i / rowsPerFile; // number of the output file
                    writer = new BufferedWriter(new FileWriter(outfile + num + ".txt"));
                }
                // Step 1: the linear predictor for this row
                double z_middle_result = CalCoeff.calCoeff(coeffcient, data[i]);
                // Step 2: turn it into a probability with the logistic function
                double result = Logistic.logistic(z_middle_result);
                String line = "Row " + (i + 1) + " result: " + result;
                writer.write(line);
                writer.newLine();
                System.out.println(line);
            }
        } finally {
            if (writer != null) {
                writer.close(); // flush and close the last file, otherwise its contents may be lost
            }
        }
    }
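The file-splitting rule used by outResult (a new file every 10,000 rows, numbered i / 10000) can be isolated and tested on its own. In this sketch the rows-per-file count and a Writer factory are parameters (ChunkedOutput.write and WriterFactory are hypothetical names), so it runs against StringWriters instead of real files. The factory interface declares IOException, which java.util.function.IntFunction cannot, so a factory such as k -> new BufferedWriter(new FileWriter(outfile + k + ".txt")) compiles and reproduces the original file names.

```java
import java.io.IOException;
import java.io.Writer;

// Sketch of outResult's splitting rule: one line per row,
// and a new Writer every rowsPerFile rows, numbered 0, 1, 2, ...
class ChunkedOutput {
    // Opens the Writer for output file number fileNumber
    interface WriterFactory {
        Writer open(int fileNumber) throws IOException;
    }

    static void write(double[] results, int rowsPerFile, WriterFactory factory) throws IOException {
        Writer out = null;
        try {
            for (int i = 0; i < results.length; i++) {
                if (i % rowsPerFile == 0) {   // first row, or the current file is full
                    if (out != null) {
                        out.close();
                    }
                    out = factory.open(i / rowsPerFile);
                }
                out.write("Row " + (i + 1) + " result: " + results[i] + System.lineSeparator());
            }
        } finally {
            if (out != null) {
                out.close();                  // close the last file as well
            }
        }
    }
}
```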

 

 VI. Analyzing data with the Weka Logistic model itself

  1. Build the model and save it

// uses weka.classifiers.functions.Logistic, weka.core.Instances, weka.core.converters.ArffLoader, java.io.*
public static void CreateLogistic(String file, String toFilePath) throws Exception {
        ArffLoader atf = new ArffLoader(); // ARFF file loader
        atf.setFile(new File(file));
        System.out.println("File loaded");
        Instances instancesTrain = atf.getDataSet(); // all rows as Instances
        instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1); // the last column is the class (target)
        System.out.println("Class attribute set");
        Logistic m_classifier = new Logistic(); // Weka's Logistic classifier
        m_classifier.buildClassifier(instancesTrain); // train on the data set
        System.out.println("Model trained");
        // try-with-resources closes the stream; a failure to open the file is no longer hidden behind a NullPointerException
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(toFilePath))) {
            oos.writeObject(m_classifier);
        }
        System.out.println(toFilePath + ": saved!");
    }

 

   2. Load the model object

public static Logistic getLogistic(String modelFilePath) throws Exception {
         // deserialize the model written by CreateLogistic
         return (Logistic) weka.core.SerializationHelper.read(modelFilePath);
    }
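CreateLogistic writes the model with a plain ObjectOutputStream, so the file is ordinary Java serialization and can also be read back with ObjectInputStream. The generic helpers below (ModelIO.save and ModelIO.load are hypothetical names) make the round trip explicit and type-checked; with Weka on the classpath, ModelIO.load(new File(path), Logistic.class) would play the role of getLogistic.

```java
import java.io.*;

// Hypothetical save/load helpers based on plain Java serialization.
class ModelIO {
    static void save(Object model, File file) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(file))) {
            out.writeObject(model);
        }
    }

    // The caller states the expected type; a wrong type fails here with ClassCastException.
    static <T> T load(File file, Class<T> type) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(file))) {
            return type.cast(in.readObject());
        }
    }
}
```

Any Serializable object works the same way, which makes the helpers easy to try without Weka, for example with an ArrayList.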

  3. Test with the model object

 

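// uses weka.core.DenseInstance, weka.core.Instance, weka.core.Instances, weka.core.Utils (Weka 3.7+)
public static double getForecast(String model, Instances header, double[] data) throws Exception {
        // note: this reloads the model on every call; cache it when scoring many rows
        Logistic m_classifier = LogisticModel.getLogistic(model);
        // header: the training data structure with its class index set (e.g. from ArffLoader.getStructure());
        // the class attribute is assumed to be the last column, as in CreateLogistic
        double[] values = new double[header.numAttributes()];
        System.arraycopy(data, 0, values, 0, data.length); // attribute values (nominal ones as value indexes)
        values[header.classIndex()] = Utils.missingValue(); // the class value is unknown
        Instance instance = new DenseInstance(1.0, values); // since Weka 3.7, Instance is an interface
        instance.setDataset(header); // the classifier needs the attribute definitions
        double[] result = m_classifier.distributionForInstance(instance); // one probability per class value
        return result[1]; // probability of the second class value
    }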
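How the two routes relate: according to the javadoc of weka.classifiers.functions.Logistic, the last class value is the reference class, and every other class j gets probability exp(x B_j) divided by (1 + the sum of those exponentials). With two class values and z = x B_0 (intercept included), this gives:

```latex
P(\text{class}_0 \mid x) = \frac{e^{z}}{1 + e^{z}} = \frac{1}{1 + e^{-z}},
\qquad
P(\text{class}_1 \mid x) = \frac{1}{1 + e^{z}} = 1 - \frac{1}{1 + e^{-z}}
```

So if the coefficient file holds the coefficients Weka prints for the first class, p from section III should match result[0], and result[1] returned by getForecast should equal 1 - p. It is worth checking this on a few rows before relying on either number.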

 

posted @ 2016-08-05 17:19  博智星