神经网络作业: NN LEARNING Coursera Machine Learning(Andrew Ng) WEEK 5

在WEEK 5中,作业要求完成通过神经网络(NN)实现多分类的逻辑回归(MULTI-CLASS LOGISTIC REGRESSION)的监督学习(SUOERVISED LEARNING)来识别阿拉伯数字。作业主要目的是感受如何在NN中求代价函数(COST FUNCTION)和其假设函数中各个参量(THETA)的求导值(GRADIENT DERIVATIVE)(利用BACKPROPAGGATION)。

难度不高,但问题是你要习惯使用MATLAB的矩阵QAQ,作为一名蒟蒻,我已经狗带了。以下代核心部分的代码希望给被作业卡住的同学一些帮助。但请不要照搬代码哦~不要~不要~

 1 ty = zeros(m, num_labels);
 2 
 3 for i=1:m
 4     for j=1:num_labels
 5         if y(i)==j
 6             ty(i,j) = 1;
 7         end
 8     end
 9 end
10 
11 a1 = X;
12 a1 = [ones(size(a1,1),1) a1];
13 z2 = a1 * Theta1';
14 a2 = sigmoid(z2);
15 a2 = [ones(size(a2,1),1) a2];
16 z3 = a2 * Theta2';
17 a3 = sigmoid(z3);
18 
19 
20 for i=1:m
21     for j=1:num_labels
22         J = J - log(1-a3(i,j))*(1-ty(i,j))/m-log(a3(i,j))*ty(i,j)/m;
23     end
24 end
25 
26 %size(J,1)
27 %size(J,2)
28     
29 d3 = a3 - ty;
30 d2 = (d3 * Theta2(:,2:end)).*sigmoidGradient(z2);
31 Theta1_grad = Theta1_grad + d2'*a1/m;
32 Theta2_grad = Theta2_grad + d3'*a2/m;
33 
34 % -------------------------------------------------------------
35 JJ=0;
36 
37  for i=1:size(Theta1,1)
38     for j=2:size(Theta1,2)
39                 JJ = JJ + Theta1(i,j)*Theta1(i,j)*lambda/(m*2);
40     end
41  end
42  size(Theta1,1);
43  size(Theta1,2);
44  
45  for i=1:size(Theta2,1)
46        for j=2:size(Theta2,2)
47            JJ = JJ + Theta2(i,j)*Theta2(i,j)*lambda/(2*m);
48       end
49 end
50 size(Theta2,1);
51 size(Theta2,2);
52 %J = J + (lambda/(2*m)) * (Theta1(:,2:end).*Theta1(:,2:end)+Theta2(2:end,:).*Theta2(2:end,:));
53 J =J+JJ;
54 
55 Theta1_gradd = zeros(size(Theta1));
56 Theta2_gradd = zeros(size(Theta2));
57 
58 for i=2:size(Theta1,2)
59     for j=1:size(Theta1,1)
60         Theta1_gradd(j,i) = Theta1(j,i)*lambda/m;
61     end
62 end
63 
64 for i=2:size(Theta2,2)
65     for j=1:size(Theta2,1)
66         Theta2_gradd(j,i) = Theta2(j,i)*lambda/m;
67     end
68 end
69 
70 Theta1_grad = Theta1_gradd+Theta1_grad;
71 Theta2_grad = Theta2_gradd+Theta2_grad;

PS:博主蒟蒻强迫自己下次要写矩阵运算,不能再套循环啦!!!

 

posted @ 2016-01-31 22:31  sllr15  阅读(351)  评论(2编辑  收藏  举报