BGD SGD MBGD 与梯度下降算法 的理解

梯度下降算法,用于求函数 极小值,以及极小值点对应的 变量值。可以是线性,非线性等。

def fun(x,y):

    return 3*x*x+2*y*y


def step(x,y):

# 定义学习率a

    a=0.02

    x=x-a*(6*x)

    y=y-a*(4*y)

    return x,y


不断调用step就可以求出函数的极小值以及极小值x,y的值。


BGD SGD MBGD 梯度下降算法针对线性回归,其损失函数极小值就是最小值。

J(θ1,θ2,θ3) 是包含θ1,θ2,θ3,以及x,y(数据点)组成的函数。已知数据点xy,求合适的θ值,使函数J最小。

用θ的梯度去更新θ的值。

BGD就是传统的梯度下降,损失函数J就是使用了全部的xy数据(函数J里的Σ包含所有的xy数据),SGD去掉了Σ(单个数据),MBGD的Σ比BGD的Σ少(使用部分xy数据)。



BGD SGD MBGD 梯度下降算法 线性回归 python 实现


import matplotlib.pyplot as plt
import numpy as np

x=np.random.random(100)
y=5+3*x+np.random.normal(0,0.2,100)


def fun(theta1,theta2,x):
    return theta1+theta2*x


# BGD

theta1=0
theta2=0
a=0.2

#画图用
BGD_path1=[]
BGD_path2=[]

for i in range(1000):
    temp1=theta1-a*(sum((fun(theta1,theta2,x)-y))/100)
    temp2=theta2-a*(sum(x*(fun(theta1,theta2,x)-y))/100)
    theta1=temp1
    theta2=temp2
    print(theta1,theta2)
    BGD_path1.append(theta1)
    BGD_path2.append(theta2)

# plt.scatter(x,y)
# plt.plot(x,fun(theta1,theta2,x))
# plt.show()

# SGD

theta1=0
theta2=0
a=0.2

#画图用
SGD_path1=[]
SGD_path2=[]

for i in range(1):
    for j in range(100):
        print(j)
        temp1=theta1-a*(fun(theta1,theta2,x[j])-y[j])
        temp2=theta2-a*(x[j]*(fun(theta1,theta2,x[j])-y[j]))
        theta1=temp1
        theta2=temp2
        print(theta1,theta2)
        SGD_path1.append(theta1)
        SGD_path2.append(theta2)

# plt.scatter(x,y)
# plt.plot(x,fun(theta1,theta2,x))
# plt.show()

# MBGD 
# mini-batch = 10 

theta1=0
theta2=0
a=0.2

#画图用
MBGD_path1=[]
MBGD_path2=[]

for i in range(100):
    for j in range(10):
        temp1=theta1-a*(sum((fun(theta1,theta2,x[j*10:j*10+10])-y[j*10:j*10+10]))/100)
        temp2=theta2-a*(sum(x[j*10:j*10+10]*(fun(theta1,theta2,x[j*10:j*10+10])-y[j*10:j*10+10]))/100)
        theta1=temp1
        theta2=temp2
        print(theta1,theta2)
        MBGD_path1.append(theta1)
        MBGD_path2.append(theta2)

# plt.scatter(x,y)
# plt.plot(x,fun(theta1,theta2,x))
# plt.show()

from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = Axes3D(fig)
ax.plot(BGD_path1,BGD_path2,1,label='BGD')
ax.plot(SGD_path1,SGD_path2,2,label='SGD')
ax.plot(MBGD_path1,MBGD_path2,3,label='MBGD')
plt.legend()
plt.show()
posted @ 2020-05-17 13:52  雪夜羽  阅读(82)  评论(0编辑  收藏