一、logistic算法的学习
Logistic 线性回归
用weka软件将系数保存到文件夹中,通过读取系数,用算法来实现分析数据的目的。
一、读取系数文件,返回系数(有分类/有数值)
coeffcient存放的就是分类系数或者数值系数。
public static void coeffcient(List<Object> coeffcient, String file) throws IOException { File inFile=new File(file); if(inFile.exists()){ BufferedReader reader = null; try { reader = new BufferedReader(new FileReader(inFile)); String inString = null; //遍历每一行 String normalName="";//记录每次类别的名称 Map<Double,Double> normal = new HashMap<Double,Double>();//记录类别的系数 boolean flag = false;//记录上一次是否是类别 while((inString = reader.readLine())!=null){ String x = inString.substring(0,inString.indexOf(" ")); String y = inString.substring(inString.indexOf(" ")).trim(); //思路:如果是只有两类或者是数值类型,则x里面没有'='号,如果是分类则x里面有'='号 //如果没有等号,则直接add to coeffcient //如果有等号,则放入map集合 if(x.indexOf("=")<=0){//不含=号 if(flag){ coeffcient.add(Collections.synchronizedMap(normal)); normal = new HashMap<Double,Double>();//重置map集合 flag=false; } coeffcient.add(Double.parseDouble(y)); }else{//如果包含=号 String thisName = x.substring(0,x.indexOf("=")); if(!thisName.equals(normalName)&&!normalName.equals("")){//新的类别并且不是第一次 //将上次类别的集合放入list if(flag){ coeffcient.add(Collections.synchronizedMap(normal)); } normal = new HashMap<Double,Double>();//重置map集合 normalName=thisName;//记录这次的类别名称 normal.put(Double.parseDouble(x.substring(x.indexOf("=")+1)), Double.parseDouble(y)); }else{ //不是新的类别或者第一次 normalName=thisName; normal.put(Double.parseDouble(x.substring(x.indexOf("=")+1)), Double.parseDouble(y)); } flag=true; } } } catch (Exception e) { e.printStackTrace(); }finally{ reader.close(); } }else{ System.out.println("文件不存在!"); return; } }
二、计算概率的第一步
如果是数值类型,则系数*数值相加,如果是分类变量,则直接系数相加。
/** * 1.数值类别或者两类的乘积和+分类类别的系数=我们想要的数据 * @param coeffcient * @param data * @return */ public static double calCoeff(List<Object> coeffcient, double[] data) { if(data.length!=(coeffcient.size()-1)){ System.out.println("样本系数的类别数:"+(coeffcient.size()-1)); System.out.println("您输入的数据文件与样本系数文件不匹配,无法分析!"); } double z=0.0; for(int i=0;i<data.length;i++){ Object c = coeffcient.get(i); if(c.getClass().getCanonicalName().equals("java.lang.Double")){ //如果是double z=z+data[i]*(Double)c; } if(c.getClass().getCanonicalName().equals("java.util.Collections.SynchronizedMap")){ @SuppressWarnings("unchecked") //如果是map Map<Double,Double> map = (Map<Double, Double>) c; z=z+map.get(data[i]); } } return z+(Double)coeffcient.get(data.length); }
三、计算概率的第二步
/**
 * Final step: maps the linear term z onto (0,1) with the sigmoid
 * 1 / (1 + e^(-z)).
 *
 * @param z_middle_result the value produced by calCoeff
 * @return the predicted probability
 */
public static double logistic(double z_middle_result) {
    // Math.exp is the direct (and more accurate) form of Math.pow(Math.E, x).
    return 1.0 / (1.0 + Math.exp(-z_middle_result));
}
四、也可以读取csv文件里面的数据来测试,需要jar包opencsv-3.8.jar
public static double[][] getDouble(String fileData) throws BiffException, IOException { //创建excel对象 File file = new File(fileData); CSVReader reader = new CSVReader(new FileReader(file)); List<String []> all= reader.readAll(); double [][] data = new double[all.size()][all.get(0).length]; for(int i =0 ;i< all.size();i++){ String [] line = all.get(i); for(int j=0;j<line.length;j++ ){ data[i][j] = Double.parseDouble(line[j]); } } return data; }
五、用取到的csv的数据进行测试
public static void outResult(double[][] data, String outfile, List<Object> coeffcient) throws IOException { int i=0;//记录行数 int num = 0;//记录数据输出的txt文件 File file = null; BufferedWriter writer = null; boolean flag = true; for (i = 0; i < data.length; i++) { if(flag){ file = new File(outfile+num+".txt"); writer = new BufferedWriter(new FileWriter(file)); flag=false; } if(i/10000!=num){ writer.close(); num=i/10000; file = new File(outfile+num+".txt"); writer = new BufferedWriter(new FileWriter(file)); } //产生系数之后通过遍历List得到这组数据对应的系数 double z_middle_result = CalCoeff.calCoeff(coeffcient, data[i]); //得到对应的系数之后通过logistic算法得到最后的结果 double result = Logistic.logistic(z_middle_result); writer.write("第" + (i + 1) + "行结果:" + result); writer.newLine(); System.out.println("第" + (i + 1) + "行结果:" + result); } try { } catch (Exception e) { System.out.println("输出结果出错!"); } }
六、通过Logistic的模型来分析数据
1.创建模型并保存
public static void CreateLogistic(String file,String toFilePath) throws Throwable { ArffLoader atf = new ArffLoader();//arff文件加载器 File inputFile = new File(file); atf.setFile(inputFile); System.out.println("读取文件成功"); Instances instancesTrain = atf.getDataSet(); //获取每一行的实例 instancesTrain.setClassIndex(instancesTrain.numAttributes()-1);//设置分析的列(即结果列) System.out.println("设置分析结果成功"); Logistic m_classifier=new Logistic();//Logistic算法对象 m_classifier.buildClassifier(instancesTrain);//Logistic加载数据集 System.out.println("加载数据集成功 "); ObjectOutputStream oos =null; try { oos = new ObjectOutputStream(new FileOutputStream(toFilePath)); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } oos.writeObject(m_classifier); oos.flush(); oos.close(); System.out.println(toFilePath+"::保存成功!"); }
2.取出模型对象
/**
 * Deserializes a previously saved Logistic model.
 *
 * @param modelFilePath path of the serialized model file
 * @return the deserialized Logistic classifier
 * @throws Exception if the model cannot be read or has the wrong type
 */
public static Logistic getLogistic(String modelFilePath) throws Exception {
    // FIX(minor): the original built a throwaway "new Logistic()" that was
    // immediately overwritten by the deserialized instance.
    return (Logistic) weka.core.SerializationHelper.read(modelFilePath);
}
3.用模型对象测试
/**
 * Loads a saved model and returns the predicted probability of class 1 for a
 * single data row.
 *
 * @param model path of the serialized Logistic model
 * @param data  one row of attribute values
 * @return the class-1 probability from the model's distribution
 * @throws Exception if the model cannot be loaded or the prediction fails
 */
public static double getForecast(String model, double[] data) throws Exception {
    Logistic classifier = LogisticModel.getLogistic(model);
    Instance instance = new Instance(data.length);
    int idx = 0;
    for (double value : data) {
        instance.setValue(idx++, value);
    }
    return classifier.distributionForInstance(instance)[1];
}