[Java] 数据分析 -- 回归分析

线性回归

  • 需求:从文件读取数据对,计算回归函数及系数
  • 实现1:commons.math的SimpleRegression,定义函数getData从文件读取数据返回SimpleRegression类
 1 import java.io.File;
 2 import java.io.FileNotFoundException;
 3 import java.util.Scanner;
 4 import org.apache.commons.math3.stat.regression.SimpleRegression;
 5 
 6 public class Example1 {
 7     public static void main(String[] args) {
 8         SimpleRegression sr = getData("data/Data1.dat");
 9         double m = sr.getSlope();
10         double b = sr.getIntercept();
11         double r = sr.getR();  // correlation coefficient
12         double r2 = sr.getRSquare();
13         double sse = sr.getSumSquaredErrors();
14         double tss = sr.getTotalSumSquares();
15 
16         System.out.printf("y = %.6fx + %.4f%n", m, b);
17         System.out.printf("r = %.6f%n", r);
18         System.out.printf("r2 = %.6f%n", r2);
19         System.out.printf("EV = %.5f%n", tss - sse);
20         System.out.printf("UV = %.4f%n", sse);
21         System.out.printf("TV = %.3f%n", tss);
22     }
23     
24     public static SimpleRegression getData(String data) {
25         SimpleRegression sr = new SimpleRegression();
26         try {
27             Scanner fileScanner = new Scanner(new File(data));
28             fileScanner.nextLine();  // read past title line
29             int n = fileScanner.nextInt();
30             fileScanner.nextLine();  // read past line of labels
31             fileScanner.nextLine();  // read past line of labels
32             for (int i = 0; i < n; i++) {
33                 String line = fileScanner.nextLine();
34                 Scanner lineScanner = new Scanner(line).useDelimiter("\\t");
35                 double x = lineScanner.nextDouble();
36                 double y = lineScanner.nextDouble();
37                 sr.addData(x, y);
38             }
39         } catch (FileNotFoundException e) {
40             System.err.println(e);
41         }
42         return sr;
43     }
44 }
View Code
  • 实现2:直接计算统计量
 1 import java.io.File;
 2 import java.io.FileNotFoundException;
 3 import java.util.Scanner;
 4 
 /**
  * Computes a simple linear regression directly from the running sums
  * of x, x*x, y, y*y and x*y read from a data file.
  */
 public class Example2 {
     private static double sX=0, sXX=0, sY=0, sYY=0, sXY=0;
     private static int n=0;

     /**
      * Loads {@code data/Data1.dat} and prints the fitted line, the correlation
      * coefficient, its square, and the EV / UV / TV sums of squares.
      *
      * @param args ignored
      */
     public static void main(String[] args) {
         getData("data/Data1.dat");
         double m = (n*sXY - sX*sY)/(n*sXX - sX*sX);
         double b = sY/n - m*sX/n;
         double r2 = m*m*(n*sXX - sX*sX)/(n*sYY - sY*sY);
         // The square root alone is never negative; the correlation coefficient
         // carries the sign of the slope, so restore it for falling data.
         double r = (m < 0) ? -Math.sqrt(r2) : Math.sqrt(r2);
         double tv = sYY - sY*sY/n;
         double mX = sX/n;  // mean value of x
         double ev = (sXX - 2*mX*sX + n*mX*mX)*m*m;
         double uv = tv - ev;

         System.out.printf("y = %.6fx + %.4f%n", m, b);
         System.out.printf("r = %.6f%n", r);
         System.out.printf("r2 = %.6f%n", r2);
         System.out.printf("EV = %.5f%n", ev);
         System.out.printf("UV = %.4f%n", uv);
         System.out.printf("TV = %.3f%n", tv);
     }

     /**
      * Reads the data pairs of a file into the static sums and count.
      *
      * <p>Layout consumed by the reads below: a title line, a line starting with
      * the integer number of pairs, one line of labels, then that many lines each
      * holding two tab-separated numbers.
      *
      * <p>The sums and the count are cleared first, so each call reflects only
      * the file it was given. If the file cannot be opened the exception is
      * printed to {@code System.err} and the sums stay at zero.
      *
      * @param data path of the data file
      */
     public static void getData(String data) {
         // Start from a clean slate; otherwise a second call would add new sums
         // onto old ones while n is simply overwritten.
         sX = sXX = sY = sYY = sXY = 0;
         n = 0;
         // try-with-resources closes the underlying file even if parsing fails
         try (Scanner fileScanner = new Scanner(new File(data))) {
             fileScanner.nextLine();  // read past title line
             n = fileScanner.nextInt();
             fileScanner.nextLine();  // consume the remainder of the count line
             fileScanner.nextLine();  // read past line of labels
             for (int i = 0; i < n; i++) {
                 String line = fileScanner.nextLine();
                 try (Scanner lineScanner = new Scanner(line).useDelimiter("\\t")) {
                     double x = lineScanner.nextDouble();
                     double y = lineScanner.nextDouble();
                     sX += x;
                     sXX += x*x;
                     sY += y;
                     sYY += y*y;
                     sXY += x*y;
                 }
             }
         } catch (FileNotFoundException e) {
             System.err.println(e);
         }
     }
 }
View Code

y = 0.882279x + 18.8739
r = 0.935222
r2 = 0.874641
EV = 1423.35676
UV = 204.0042
TV = 1627.361

  • 实现3:对辅助类进行实例化,并绘图

Example3.java

 1 import java.io.File;
 2 import javax.swing.JFrame;
 3 
 4 public class Example3 {
 5     public static void main(String[] args) {
 6         Data data = new Data(new File("data/Data1.dat"));
 7         JFrame frame = new JFrame(data.getTitle());
 8         frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
 9         RegressionPanel panel = new RegressionPanel(data);
10         frame.add(panel);
11         frame.pack();
12         frame.setSize(500, 422);
13         frame.setResizable(false);
14         frame.setLocationRelativeTo(null);  // center frame on screen
15         frame.setVisible(true);
16     }
17 }
View Code

Data.java

  1 import java.io.File;
  2 import java.io.FileNotFoundException;
  3 import java.util.Scanner;
  4 
  /**
   * A bivariate data set read from a text file, together with the summary
   * statistics and least-squares line computed while reading it.
   *
   * <p>Layout consumed by the constructor: a title line, the integer number of
   * pairs, two whitespace-separated column names, then that many x/y pairs.
   */
  public class Data {
      private String title,xName, yName;
      private int n;
      // Start as empty arrays so getX()/getY() never return null when loading fails
      private double[] x = new double[0], y = new double[0];
      private double sX, sXX, sY, sYY, sXY, minX, minY, maxX, maxY;
      private double meanX, meanY, slope, intercept, corrCoef;

      /**
       * Reads the data set and computes sums, extremes, means, slope,
       * intercept and correlation coefficient.
       *
       * <p>If the file cannot be opened the exception is printed to
       * {@code System.err}; the object then holds no points, empty arrays
       * and default field values.
       *
       * @param inputFile the file to read
       */
      public Data(File inputFile) {
          // try-with-resources closes the file even if parsing fails
          try (Scanner input = new Scanner(inputFile)) {
              title = input.nextLine();
              n = input.nextInt();
              xName = input.next();
              yName = input.next();
              input.nextLine();  // consume the remainder of the names line
              x = new double[n];
              y = new double[n];
              minX = minY = Double.POSITIVE_INFINITY;
              maxX = maxY = Double.NEGATIVE_INFINITY;
              for (int i = 0; i < n; i++) {
                  double xi = x[i] = input.nextDouble();
                  double yi = y[i] = input.nextDouble();
                  sX += xi;
                  sXX += xi*xi;
                  sY += yi;
                  sYY += yi*yi;
                  sXY += xi*yi;
                  minX = (xi < minX? xi: minX);
                  minY = (yi < minY? yi: minY);
                  maxX = (xi > maxX? xi: maxX);
                  maxY = (yi > maxY? yi: maxY);
              }
              meanX = sX/n;
              meanY = sY/n;
              slope = (n*sXY - sX*sY)/(n*sXX - sX*sX);
              intercept = meanY - slope*meanX;
              // Multiplying by the slope keeps the sign of the correlation
              corrCoef = slope*Math.sqrt((n*sXX - sX*sX)/(n*sYY - sY*sY));
          } catch (FileNotFoundException e) {
              System.err.println(e);
          }
      }

      /** @return the first line of the file, or {@code null} if it was not read */
      public String getTitle() {
          return title;
      }

      /** @return the name of the x column, or {@code null} if it was not read */
      public String getXName() {
          return xName;
      }

      /** @return the name of the y column, or {@code null} if it was not read */
      public String getYName() {
          return yName;
      }

      /** @return the number of data pairs */
      public int getN() {
          return n;
      }

      /** @return the internal array of x values (not a copy) */
      public double[] getX() {
          return x;
      }

      /** @return the internal array of y values (not a copy) */
      public double[] getY() {
          return y;
      }

      /** @return the mean of the x values */
      public double getMeanX() {
          return meanX;
      }

      /** @return the mean of the y values */
      public double getMeanY() {
          return meanY;
      }

      /** @return the slope of the least-squares line */
      public double getSlope() {
          return slope;
      }

      /** @return the intercept of the least-squares line */
      public double getIntercept() {
          return intercept;
      }

      /** @return the correlation coefficient */
      public double getCorrCoef() {
          return corrCoef;
      }

      /**
       * Builds a fresh table with one row per data pair.
       *
       * @return an {@code n x 2} array whose columns are x and y
       */
      public double[][] getTable() {
          double[][] table = new double[n][2];
          for (int i = 0; i < n; i++) {
              table[i][0] = x[i];
              table[i][1] = y[i];
          }
          return table;
      }

      /** @return the smallest x value */
      public double getMinX() {
          return minX;
      }

      /** @return the smallest y value */
      public double getMinY() {
          return minY;
      }

      /** @return the largest x value */
      public double getMaxX() {
          return maxX;
      }

      /** @return the largest y value */
      public double getMaxY() {
          return maxY;
      }
  }
View Code

RegressionPanel.java

import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.Graphics2D;
import javax.swing.JPanel;

/**
 * Panel that draws a grid, the points of a {@code Data} set and its
 * regression line.
 *
 * <p>Data coordinates are mapped to pixels by {@link #f(double)} and
 * {@link #g(double)}; the plotted area is offset by {@code BUFFER + MARGIN}
 * horizontally and {@code MARGIN} vertically.
 */
public class RegressionPanel extends JPanel {
    private static final int WIDTH=500, HEIGHT=400, BUFFER=28, MARGIN=40;
    private final Data data;
    private double xMin, xMax, yMin, yMax, xRange, yRange, gWidth, gHeight;
    private double slope, intercept;

    /**
     * Captures the extremes and the regression line of the data set and
     * derives the data ranges and the pixel size of the plotted area.
     *
     * @param data the data set to plot
     */
    public RegressionPanel(Data data) {
        this.data = data;
        this.setSize(WIDTH, HEIGHT);
        this.xMin = data.getMinX();
        this.xMax = data.getMaxX();
        this.yMin = data.getMinY();
        this.yMax = data.getMaxY();
        this.slope = data.getSlope();
        this.intercept = data.getIntercept();
        this.xRange = xMax - xMin;
        this.yRange = yMax - yMin;
        this.gWidth = WIDTH - 2*MARGIN - BUFFER;
        this.gHeight = HEIGHT - 2*MARGIN - BUFFER;
        setBackground(Color.WHITE);
    }

    /** Paints the grid first, then the data points, then the regression line. */
    @Override
    public void paintComponent(Graphics g) {
        super.paintComponent(g);
        Graphics2D g2 = (Graphics2D)g;
        g2.setStroke(new BasicStroke(1));
        drawGrid(g2);
        drawPoints(g2, data.getX(), data.getY());
        drawLine(g2);
    }

    /**
     * Draws light-gray grid lines with black integer labels along both axes.
     * The first and last grid values are the data extremes rounded outwards
     * to a multiple of the grid step of that axis.
     */
    private void drawGrid(Graphics2D g2) {
        g2.setStroke(new BasicStroke(1));
        int xd = gridStep(xRange);
        int x0 = dToI(xd*Math.floor(xMin/xd));
        int xn = dToI(xd*Math.ceil(xMax/xd));
        for (int xi = x0; xi <= xn; xi += xd) {
            g2.setColor(Color.LIGHT_GRAY);
            int p = f(xi);
            g2.drawLine(p, 0, p, HEIGHT-18);  // vertical lines
            g2.setColor(Color.BLACK);
            g2.drawString(""+xi, p-8, HEIGHT-4);
        }
        // The y axis uses its own step and its own extremes
        int yd = gridStep(yRange);
        int y0 = dToI(yd*Math.floor(yMin/yd));
        int yn = dToI(yd*Math.ceil(yMax/yd));
        for (int yi = y0; yi <= yn; yi += yd) {
            g2.setColor(Color.LIGHT_GRAY);
            int q = g(yi);
            g2.drawLine(BUFFER, q, WIDTH, q);  // horizontal lines
            g2.setColor(Color.BLACK);
            g2.drawString((yi<100?"  ":"")+yi, 2, q+5);
        }
    }

    /**
     * Returns the grid spacing for an axis: the largest power of ten not
     * exceeding {@code range}, but never less than 1. The lower bound keeps
     * the integer grid loops advancing when the range is small, zero or not
     * a finite positive number.
     */
    private int gridStep(double range) {
        return Math.max(1, dToI(Math.pow(10, Math.floor(Math.log10(range)))));
    }

    /** Draws each (x[i], y[i]) pair as a filled black 6x6 oval centred on the point. */
    private void drawPoints(Graphics2D g2, double[] x, double[] y) {
        g2.setColor(Color.BLACK);
        for (int i = 0; i < x.length; i++) {
            int u = f(x[i]);
            int v = g(y[i]);
            g2.fillOval(u-3, v-3, 6, 6);  // coordinates are at NW corners
        }
    }

    /** Draws the regression line in blue from pixel column BUFFER to WIDTH. */
    private void drawLine(Graphics2D g2) {
        g2.setColor(Color.BLUE);
        g2.setStroke(new BasicStroke(2));
        int p0 = BUFFER;
        int q0 = g(yLine(fInv(p0)));
        int p1 = WIDTH;
        int q1 = g(yLine(fInv(p1)));
        g2.drawLine(p0, q0, p1, q1);
    }

    /** Evaluates the regression line at data coordinate x. */
    private double yLine(double x) {
        return slope*x + intercept;
    }

    /** Rounds a double to the nearest int. */
    private int dToI(double x) {
        return (int)Math.round(x);
    }

    /** Maps a data x coordinate to a horizontal pixel position. */
    private int f(double x) {
        return dToI((x - xMin)*gWidth/xRange) + BUFFER + MARGIN;
    }

    /** Maps a data y coordinate to a vertical pixel position (larger y is higher up). */
    private int g(double y) {
        return dToI(gHeight - (y - yMin)*gHeight/yRange) + MARGIN;
    }

    /** Maps a horizontal pixel position back to a data x coordinate. */
    private double fInv(int p) {
        return (p - BUFFER - MARGIN)*xRange/gWidth + xMin;
    }

    /** Maps a vertical pixel position back to a data y coordinate. */
    private double gInv(int q) {
        return yMin + (gHeight + MARGIN - q)*yRange/gHeight;
    }
}
View Code

 

 

 

多项式回归

  • 需求:已知刹车速度和距离的数据,求解距离关于速度的二次回归多项式
  • 实现:最小二乘法,解方程组,LU分解

 

 1 import org.apache.commons.math3.linear.*;
 2 
 3 public class Example4 {
 4     static double[] x = {20, 30, 40, 50, 60, 70};
 5     static double[] y = {52, 87, 136, 203, 290, 394};
 6     static int n = y.length;  // 6
 7 
 8     public static void main(String[] args) {
 9         double[][] a = new double[3][3];
10         double[] w = new double[3];
11         deriveNormalEquations(a, w);
12         printNormalEquations(a, w);
13         double[] b = solveNormalEquations(a, w);
14         printResults(b);
15     }
16 
17     public static void deriveNormalEquations(double[][] a, double[] w) {
18         for (int i = 0; i < n; i++) {
19             double xi = x[i];
20             double yi = y[i];
21             a[0][0] = n;
22             a[0][1] = a[1][0] += xi;
23             a[0][2] = a[1][1] = a[2][0] += xi*xi;
24             a[1][2] = a[2][1] += xi*xi*xi;
25             a[2][2] += xi*xi*xi*xi;
26             w[0] += yi;
27             w[1] += xi*yi;
28             w[2] += xi*xi*yi;
29         }
30     }
31 
32     public static void printNormalEquations(double[][] a, double[] w) {
33         for (int i = 0; i < 3; i++) {
34             System.out.printf("%8.0fb0 + %6.0fb1 + %8.0fb2 = %7.0f%n",
35                     a[i][0], a[i][1], a[i][2], w[i]);
36         }
37     }
38 
39     /*  Solves the matrix equation a*b = w for b[], representing a[] 
40         as RealMatrix m and b[] as RealVector v: 
41      */
42     private static double[] solveNormalEquations(double[][] a, double[] w) {
43             RealMatrix m = new Array2DRowRealMatrix(a, false);
44             LUDecomposition lud = new LUDecomposition(m);
45             DecompositionSolver solver = lud.getSolver();
46             RealVector v = new ArrayRealVector(w, false);
47             return solver.solve(v).toArray();
48     }
49     
50     private static void printResults(double[] b) {
51         System.out.printf("f(t) = %.2f + %.3ft + %.5ft^2%n", b[0], b[1], b[2]);
52         System.out.printf("f(55) = %.1f%n", f(55, b));
53     }
54     
55     private static double f(double t, double[] b) {
56         return b[0] + b[1]*t + b[2]*t*t;
57     }
58 }
View Code

6b0 + 270b1 + 13900b2 = 1162
270b0 + 13900b1 + 783000b2 = 64220
13900b0 + 783000b1 + 46750000b2 = 3798800
f(t) = 40.73 + -1.170t + 0.08875t^2
f(55) = 244.8

多元线性回归

  • 需求:变量y依赖于多个变量
  • 实现:直接求解或通过Apache Commons

Example5.java

 1 import org.apache.commons.math3.linear.*;
 2 
 3 public class Example5 {
 4     static double[] x = {10, 9, 12, 10, 9, 10, 8, 11};
 5     static double[] y = {59, 57, 61, 52, 48, 55, 51, 62};
 6     static double[] z = {71, 68, 76, 56, 57, 77, 55, 67};
 7     static int n = z.length;  // 8
 8             
 9     public static void main(String[] args) {
10         double[][] a = new double[3][3];
11         double[] w = new double[3];
12         deriveNormalEquations(a, w);
13         printNormalEquations(a, w);
14         double[] b = solveNormalEquations(a, w);
15         printResults(b);
16     }
17 
18     public static void deriveNormalEquations(double[][] a, double[] w) {
19         for (int i = 0; i < n; i++) {
20             double xi = x[i];
21             double yi = y[i];
22             double zi = z[i];
23             a[0][0] = n;
24             a[0][1] = a[1][0] += xi;
25             a[0][2] = a[2][0] += yi;
26             a[1][1] += xi*xi;
27             a[1][2] = a[2][1] += xi*yi;
28             a[2][2] += yi*yi;
29             w[0] += zi;
30             w[1] += xi*zi;
31             w[2] += yi*zi;
32         }
33     }
34 
35     public static void printNormalEquations(double[][] a, double[] w) {
36         for (int i = 0; i < 3; i++) {
37             System.out.printf("%6.0fx0 + %4.0fx1 + %5.0fx2 = %5.0f%n",
38                     a[i][0], a[i][1], a[i][2], w[i]);
39         }
40     }
41 
42     private static double[] solveNormalEquations(double[][] a, double[] w) {
43         RealMatrix m = new Array2DRowRealMatrix(a, false);
44         LUDecomposition lud = new LUDecomposition(m);
45         DecompositionSolver solver = lud.getSolver();
46         RealVector v = new ArrayRealVector(w, false);
47         return solver.solve(v).toArray();
48     }
49     
50     private static void printResults(double[] b) {
51         System.out.printf("f(s, t) = %.2f + %.2fs + %.2ft%n", b[0], b[1], b[2]);
52         System.out.printf("f(10, 59) = %.1f%n", f(10, 59, b));
53         System.out.printf("f(9, 57) = %.1f%n", f(9, 57, b));
54         System.out.printf("f(11, 64) = %.1f%n", f(11, 64, b));
55     }
56     
57     private static double f(double s, double t, double[] b) {
58         return b[0] + b[1]*s + b[2]*t;
59     }
60 }
View Code

 

Example6.java

 1 import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
 2 
 3 public class Example6 {
 4     static double[][] x = { {10, 59}, {9, 57}, {12, 61}, {10, 52}, {9, 48}, 
 5             {10, 55}, {8, 51}, {11, 62} };
 6     static double[] y = {71, 68, 76, 56, 57, 77, 55, 67};
 7 
 8     public static void main(String[] args) {
 9         OLSMultipleLinearRegression mlr = new OLSMultipleLinearRegression();
10         mlr.newSampleData(y, x);
11         double[] b = mlr.estimateRegressionParameters();
12         printResults(b);
13     }
14     
15     private static void printResults(double[] b) {
16         System.out.printf("f(s, t) = %.2f + %.2fs + %.2ft%n", b[0], b[1], b[2]);
17         System.out.printf("f(10, 59) = %.1f%n", f(10, 59, b));
18         System.out.printf("f(9, 57) = %.1f%n", f(9, 57, b));
19         System.out.printf("f(11, 64) = %.1f%n", f(11, 64, b));
20     }
21     
22     private static double f(double s, double t, double[] b) {
23         return b[0] + b[1]*s + b[2]*t;
24     }
25 }
View Code

8x0 + 79x1 + 445x2 = 527
79x0 + 791x1 + 4427x2 = 5254
445x0 + 4427x1 + 24929x2 = 29543
f(s, t) = -5.75 + 1.55s + 1.01t
f(10, 59) = 69.5
f(9, 57) = 65.9
f(11, 64) = 76.1

posted @ 2021-04-22 08:34  cxc1357  阅读(783)  评论(0编辑  收藏  举报