一段有关线搜索的从python到matlab的代码
在Udacity上很多关于机器学习的课程几乎都是基于python语言的,博主“ttang”的博文“重新发现梯度下降法——backtracking line search”里对回溯线搜索的算法实现也是用python写的,这对没有接触过python的我来说,内心是非常“抓狂”的。看到代码有想看到运行结果的冲动,暂时又不想去下载软件,好在这段代码简单、清晰,不信,你看原代码【1】
# -*- coding: cp936 -*- #optimization test, y = (x-3)^2 from matplotlib.pyplot import figure, hold, plot, show, xlabel, ylabel, legend def f(x): "The function we want to minimize" return (x-3)**2 def f_grad(x): "gradient of function f" return 2*(x-3) x = 0 y = f(x) err = 1.0 maxIter = 300 curve = [y] it = 0 step = 0.1 #下面展示的是我之前用的方法,看上去貌似还挺合理的,但是很慢 while err > 1e-4 and it < maxIter: it += 1 gradient = f_grad(x) new_x = x - gradient * step new_y = f(new_x) new_err = abs(new_y - y) if new_y > y: #如果出现divergence的迹象,就减小step size step *= 0.8 err, x, y = new_err, new_x, new_y print 'err:', err, ', y:', y curve.append(y) print 'iterations: ', it figure(); hold(True); plot(curve, 'r*-') xlabel('iterations'); ylabel('objective function value') #下面展示的是backtracking line search,速度很快 x = 0 y = f(x) err = 1.0 alpha = 0.25 beta = 0.8 curve2 = [y] it = 0 while err > 1e-4 and it < maxIter: it += 1 gradient = f_grad(x) step = 1.0 while f(x - step * gradient) > y - alpha * step * gradient**2: step *= beta x = x - step * gradient new_y = f(x) err = y - new_y y = new_y print 'err:', err, ', y:', y curve2.append(y) print 'iterations: ', it plot(curve2, 'bo-') legend(['gradient descent I used', 'backtracking line search']) show()
确实是为了观察实验结果,暂时又不想去装python,就把上面的代码改成了matlab code
% optimization test, y = (x-3)^2 % -*- zw -*- f=@(x)(x-3)^2; diff_f=@(x)2*(x-3); x = 0; y = f(x); err = 1.0; maxIter = 300; curve = []; iter = 0; step = 0.1; % 下面展示的是我之前用的方法,看上去貌似还挺合理的,但是很慢 while err > 1e-4 && iter < maxIter iter=iter+ 1; gradient = diff_f(x); new_x = x - gradient * step; new_y = f(new_x); new_err = abs(new_y - y); if new_y > y % 如果出现divergence的迹象,就减小step size step =step* 0.8; end err=new_err; x=new_x; y=new_y; fprintf('iteration: %d, err: %f, y: %f \n',iter, err, y); curve(iter)=y; end figure(); axes('linewidth',1, 'box', 'on', 'FontSize',16);
hold on; plot(curve, 'r*-') xlabel('iterations'); ylabel('objective function value') % 下面展示的是backtracking line search,速度很快 x = 0; y = f(x); err = 1.0; alpha = 0.25; beta = 0.8; curve2 = []; iter = 0; while err > 1e-4 && iter < maxIter iter =iter+ 1; gradient = diff_f(x); step = 1.0; while f(x - step * gradient) > y - alpha * step * gradient^2 step =step* beta; end x = x - step * gradient; new_y = f(x); err = y - new_y; y = new_y; fprintf( 'iteration: %d, err: %f, y: %f \n', iter,err, y); curve2(iter)=y; end plot(curve2, 'bo-') legend('gradient descent I used', 'backtracking line search')
Matlab代码运行的结果

部分细节问题可以参考和对比原博文【1】。
参考:
【1】http://www.cnblogs.com/fstang/p/4192735.html

浙公网安备 33010602011771号