A line-search code snippet, ported from Python to MATLAB

  Many of the machine-learning courses on Udacity are based on Python, and the post "重新发现梯度下降法——backtracking line search" by blogger "ttang" also implements backtracking line search in Python. For someone like me who has never touched Python, this was rather maddening: reading the code made me want to see it run, but I didn't want to download the software just yet. Fortunately the code is short and clear; see for yourself in the original listing [1]:

# optimization test: minimize y = (x-3)^2
from matplotlib.pyplot import figure, plot, show, xlabel, ylabel, legend
def f(x):
    """The function we want to minimize."""
    return (x - 3)**2

def f_grad(x):
    """Gradient of f."""
    return 2*(x - 3)
x = 0
y = f(x)
err = 1.0
maxIter = 300
curve = [y]
it = 0
step = 0.1
# The method I used before: it looks reasonable enough, but it is slow
while err > 1e-4 and it < maxIter:
    it += 1
    gradient = f_grad(x)
    new_x = x - gradient * step
    new_y = f(new_x)
    new_err = abs(new_y - y)
    if new_y > y:  # if the objective increases (a sign of divergence), shrink the step size
        step *= 0.8
    err, x, y = new_err, new_x, new_y
    print('err:', err, ', y:', y)
    curve.append(y)

print('iterations: ', it)
figure(); plot(curve, 'r*-')
xlabel('iterations'); ylabel('objective function value')

# Backtracking line search: much faster
x = 0
y = f(x)
err = 1.0
alpha = 0.25
beta = 0.8
curve2 = [y]
it = 0

while err > 1e-4 and it < maxIter:
    it += 1
    gradient = f_grad(x)
    step = 1.0
    while f(x - step * gradient) > y - alpha * step * gradient**2:
        step *= beta
    x = x - step * gradient
    new_y = f(x)
    err = y - new_y
    y = new_y
    print('err:', err, ', y:', y)
    curve2.append(y)

print('iterations: ', it)
plot(curve2, 'bo-')
legend(['gradient descent I used', 'backtracking line search'])
show()
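
  A note for readers who, like me, are meeting this for the first time (this explanation is mine, not from the original post): the inner while loop above implements the standard sufficient-decrease (Armijo) backtracking test. Moving along the descent direction -f'(x), a trial step t is accepted as soon as

    f(x - t*f'(x)) <= f(x) - alpha * t * f'(x)^2

and is otherwise shrunk to beta*t. The usual requirements are 0 < alpha < 0.5 and 0 < beta < 1; the code uses alpha = 0.25 and beta = 0.8.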

  Really just to see the experimental results, while still not wanting to install Python for now, I converted the code above into MATLAB code:

% optimization test: minimize y = (x-3)^2
f = @(x) (x - 3)^2;       % objective
diff_f = @(x) 2*(x - 3);  % its gradient
x = 0;
y = f(x);
err = 1.0;
maxIter = 300;
curve = y;  % record objective values, starting with the initial one (as in the Python version)
iter = 0;
step = 0.1;
% The method I used before: it looks reasonable enough, but it is slow
while err > 1e-4 && iter < maxIter
    iter = iter + 1;
    gradient = diff_f(x);
    new_x = x - gradient * step;
    new_y = f(new_x);
    new_err = abs(new_y - y);
    if new_y > y
        % if the objective increases (a sign of divergence), shrink the step size
        step = step * 0.8;
    end
    err = new_err;
    x = new_x;
    y = new_y;

    fprintf('iteration: %d, err: %f, y: %f\n', iter, err, y);
    curve(end+1) = y;
end

figure(); axes('LineWidth', 1, 'Box', 'on', 'FontSize', 16);
hold on; plot(curve, 'r*-');
xlabel('iterations'); ylabel('objective function value');

% Backtracking line search: much faster
x = 0;
y = f(x);
err = 1.0;
alpha = 0.25;
beta = 0.8;
curve2 = y;
iter = 0;

while err > 1e-4 && iter < maxIter
    iter = iter + 1;
    gradient = diff_f(x);
    step = 1.0;
    while f(x - step * gradient) > y - alpha * step * gradient^2
        step = step * beta;
    end
    x = x - step * gradient;
    new_y = f(x);
    err = y - new_y;
    y = new_y;
    fprintf('iteration: %d, err: %f, y: %f\n', iter, err, y);
    curve2(end+1) = y;
end

plot(curve2, 'bo-');
legend('gradient descent I used', 'backtracking line search');
[Figure: the result of running the MATLAB code, objective function value vs. iterations for both methods]
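
  Why backtracking wins by such a margin here is easy to check by hand (a back-of-the-envelope calculation of mine, not from the original post). For f(x) = (x-3)^2, a gradient step of size t maps the error x - 3 to (1 - 2t)(x - 3), so the fixed step t = 0.1 contracts the error by a factor of only 0.8 per iteration. The backtracking loop starts from t = 1.0 and shrinks it just twice (1.0 -> 0.8 -> 0.64) before the sufficient-decrease test passes; t = 0.64 contracts the error by |1 - 2*0.64| = 0.28 per iteration, much closer to the exact line-search step t = 0.5, which would land on the minimizer x = 3 in a single step.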

For some of the finer details, compare with the original post [1].

References:

[1] http://www.cnblogs.com/fstang/p/4192735.html
