TensorFlow实现梯度下降

 # -*- coding: utf-8 -*-
"""
Batch gradient descent for linear regression on the California housing data.

Builds a TensorFlow 1.x graph that fits ``y ~ X @ theta`` by minimising the
mean squared error over the whole dataset, prints the MSE every 100 epochs,
and prints the learned ``theta`` (bias first, then one weight per feature).

NOTE: this uses the TF 1.x graph API (tf.Session, tf.random_uniform,
tf.train.GradientDescentOptimizer, tf.global_variables_initializer). Under
TensorFlow 2.x these are only available via ``tf.compat.v1`` with eager
execution disabled.

Created on Mon Oct 15 17:38:39 2018

@author: zhen
"""

import tensorflow as tf
import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler

n_epochs = 10000
learning_rate = 0.01

# Machine-specific cache directory for the dataset download; adjust locally.
DATA_HOME = "C:/Users/zhen/.spyder-py3/data"

housing = fetch_california_housing(data_home=DATA_HOME, download_if_missing=True)
m, n = housing.data.shape

# Standardize the n feature columns FIRST, then prepend the bias column.
# Scaling after adding the bias would turn the constant column of ones into
# all zeros (mean 1, variance 0), so theta[0] would never receive a gradient
# and the model could not learn an intercept.
scaler = StandardScaler().fit(housing.data)
scaled_housing_data = scaler.transform(housing.data)
scaled_housing_data_plus_bias = np.c_[np.ones((m, 1)), scaled_housing_data]

# Constant graph inputs: design matrix of shape (m, n + 1), targets (m, 1).
x = tf.constant(scaled_housing_data_plus_bias, dtype=tf.float32, name="x")
y = tf.constant(housing.target.reshape(-1, 1), dtype=tf.float32, name="y")

# Trainable parameters (bias + n weights), initialised uniformly in [-1, 1).
theta = tf.Variable(tf.random_uniform([n + 1, 1], -1.0, 1.0), name="theta")

# Linear model and mean-squared-error loss.
y_pred = tf.matmul(x, theta, name="predictions")
error = y_pred - y
mse = tf.reduce_mean(tf.square(error), name="mse")

# Equivalent manual update, kept for reference:
#   gradients = tf.gradients(mse, [theta])[0]
#   training_op = tf.assign(theta, theta - learning_rate * gradients)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(mse)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    for epoch in range(n_epochs):
        if epoch % 100 == 0:
            # mse.eval() runs in the default session opened by the `with`.
            print("Epoch", epoch, "MSE = ", mse.eval())
        sess.run(training_op)

    best_theta = theta.eval()
    print(best_theta)

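结果:

（注：以下输出对应“先拼接全 1 偏置列、再对整个矩阵做 StandardScaler 标准化”的写法。全 1 列的方差为 0，标准化后会变成全 0，因此 theta[0] 始终得不到梯度——最后打印的 0.43350863 就是它的随机初始值。模型因此没有截距，MSE 停在约 4.80 ≈ mean(y)² + 残差方差。若先只对特征做标准化、再拼接全 1 列，模型即可学到截距，MSE 可降到约 0.52。）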

Epoch 0 MSE =  9.128207
Epoch 100 MSE =  4.893214
Epoch 200 MSE =  4.8329406
Epoch 300 MSE =  4.824335
Epoch 400 MSE =  4.8187895
Epoch 500 MSE =  4.814753
Epoch 600 MSE =  4.811796
Epoch 700 MSE =  4.8096204
Epoch 800 MSE =  4.808017
Epoch 900 MSE =  4.806835
Epoch 1000 MSE =  4.805955
Epoch 1100 MSE =  4.805301
Epoch 1200 MSE =  4.8048124
Epoch 1300 MSE =  4.804449
Epoch 1400 MSE =  4.804172
Epoch 1500 MSE =  4.803962
Epoch 1600 MSE =  4.8038034
Epoch 1700 MSE =  4.803686
Epoch 1800 MSE =  4.8035927
Epoch 1900 MSE =  4.80352
Epoch 2000 MSE =  4.8034678
Epoch 2100 MSE =  4.803425
Epoch 2200 MSE =  4.8033857
Epoch 2300 MSE =  4.803362
Epoch 2400 MSE =  4.803341
Epoch 2500 MSE =  4.8033247
Epoch 2600 MSE =  4.80331
Epoch 2700 MSE =  4.8033013
Epoch 2800 MSE =  4.8032923
Epoch 2900 MSE =  4.8032856
Epoch 3000 MSE =  4.8032804
Epoch 3100 MSE =  4.803273
Epoch 3200 MSE =  4.803271
Epoch 3300 MSE =  4.8032694
Epoch 3400 MSE =  4.803267
Epoch 3500 MSE =  4.8032637
Epoch 3600 MSE =  4.8032603
Epoch 3700 MSE =  4.803259
Epoch 3800 MSE =  4.803259
Epoch 3900 MSE =  4.8032584
Epoch 4000 MSE =  4.8032575
Epoch 4100 MSE =  4.8032575
Epoch 4200 MSE =  4.803256
Epoch 4300 MSE =  4.803255
Epoch 4400 MSE =  4.803256
Epoch 4500 MSE =  4.803256
Epoch 4600 MSE =  4.803253
Epoch 4700 MSE =  4.8032565
Epoch 4800 MSE =  4.803258
Epoch 4900 MSE =  4.8032556
Epoch 5000 MSE =  4.803256
Epoch 5100 MSE =  4.8032537
Epoch 5200 MSE =  4.8032565
Epoch 5300 MSE =  4.803255
Epoch 5400 MSE =  4.8032546
Epoch 5500 MSE =  4.803254
Epoch 5600 MSE =  4.8032537
Epoch 5700 MSE =  4.8032517
Epoch 5800 MSE =  4.8032527
Epoch 5900 MSE =  4.8032537
Epoch 6000 MSE =  4.803254
Epoch 6100 MSE =  4.8032546
Epoch 6200 MSE =  4.803255
Epoch 6300 MSE =  4.8032546
Epoch 6400 MSE =  4.803253
Epoch 6500 MSE =  4.803253
Epoch 6600 MSE =  4.803253
Epoch 6700 MSE =  4.8032517
Epoch 6800 MSE =  4.803252
Epoch 6900 MSE =  4.8032517
Epoch 7000 MSE =  4.803252
Epoch 7100 MSE =  4.8032537
Epoch 7200 MSE =  4.8032537
Epoch 7300 MSE =  4.803253
Epoch 7400 MSE =  4.803253
Epoch 7500 MSE =  4.803253
Epoch 7600 MSE =  4.803254
Epoch 7700 MSE =  4.8032546
Epoch 7800 MSE =  4.8032556
Epoch 7900 MSE =  4.803256
Epoch 8000 MSE =  4.8032565
Epoch 8100 MSE =  4.8032565
Epoch 8200 MSE =  4.8032565
Epoch 8300 MSE =  4.8032556
Epoch 8400 MSE =  4.8032565
Epoch 8500 MSE =  4.8032575
Epoch 8600 MSE =  4.8032565
Epoch 8700 MSE =  4.803256
Epoch 8800 MSE =  4.803256
Epoch 8900 MSE =  4.8032556
Epoch 9000 MSE =  4.803255
Epoch 9100 MSE =  4.8032546
Epoch 9200 MSE =  4.803254
Epoch 9300 MSE =  4.8032546
Epoch 9400 MSE =  4.8032546
Epoch 9500 MSE =  4.803255
Epoch 9600 MSE =  4.803255
Epoch 9700 MSE =  4.803255
Epoch 9800 MSE =  4.803255
Epoch 9900 MSE =  4.803255
[[ 0.43350863]
 [ 0.8296331 ]
 [ 0.11875448]
 [-0.26555073]
 [ 0.3057157 ]
 [-0.00450223]
 [-0.03932685]
 [-0.8998542 ]
 [-0.87051094]]
View Code

结果样例:

 

posted @ 2018-10-17 10:16  云山之巅  阅读(1217)  评论(0编辑  收藏  举报