A toy example of Reinforcement Learning (matlab code)

 如下图所示:

假设我们有一个agent,有三个状态S = {s1,s2,s3},有三个操作A = {a1,a2,a3},给定每个状态下进行不同操作的奖励 R(s,a),如何进行Q-Learning? 下面是我给出的一个matlab实现:

 1 %% PART1: RULE DEFINITION
 2 
 3 % S = {'s1', 's2', 's3'};
 4 % A = {'up','left','right'};
 5 % SS = {'s3','s1','s2';...
 6 %           's3','s1','s2';...
 7 %              [],   [],  []};
 8 
 9 S   = [1, 2, 3];  % state set
10 A   = [1, 2, 3];  % action set
11 SS = [3, 1, 2;...  % next state, namly s', transformed from state s given the selection of action a
12           3, 1, 2];
13 
14 R = [-0.2,  -0.1,   0.2;... % reward for agent given state 's' and take action 'a'
15           0.1, -0.2,  -0.1;...
16             [],     [],   []];
17 
%% PART2: RL using Q-LEARNING
% Runs tabular Q-learning once per discount factor in `gamma`.
% Inputs (defined in PART1): S, A, SS, R.
% Outputs: QQs{j} is an (nS-1) x nA x maxIterations array holding the Q-table
%          after every episode of the run with gamma(j). QQ and Q hold the
%          data of the last run.
% The update overwrites Q(s,a) with the new target, i.e. the learning rate
% is implicitly 1.
nS = length(S);   % number of states
nA = length(A);   % number of actions

gamma = 0 : 0.2 : 1;     % list of discount factors to compare
epsilon0 = 1;            % initial exploration probability of the epsilon-greedy policy
epsilon = epsilon0;      % current exploration probability (decayed after every episode)
epsilon_decay = 0.9;     % multiplicative decay per episode; 1 keeps epsilon constant
maxIterations = 100;     % number of episodes per gamma
maxSteps = 1000;         % safety cap on steps per episode (see the while loop below)
QQs = cell(1, length(gamma));   % QQs{j}: Q-table history for gamma(j)

for j = 1:length(gamma)
    % Restart exploration for every gamma. Without this reset epsilon keeps
    % decaying across runs, so only the first gamma would ever explore.
    epsilon = epsilon0;

    % Q-table over the non-terminal states only (use rand(nS, nA) for the full state set)
    Q = rand(nS-1, nA);
    % Q after every episode (use zeros(nS, nA, maxIterations) for the full state set)
    QQ = zeros(nS-1, nA, maxIterations);

    for i = 1:maxIterations
        s = S(randi(nS));   % random start state; may already be the terminal one
        nSteps = 0;
        % The step cap prevents an endless episode: with gamma = 1 the greedy
        % policy can settle into the cycle s1 -> s2 -> s1, whose net reward is
        % zero, and then never reach the terminal state.
        while s ~= S(end) && nSteps < maxSteps
            nSteps = nSteps + 1;
            if rand() < epsilon
                a = A(randi(nA));             % exploration: uniformly random action
            else
                [~, best_id] = max(Q(s,:));   % exploitation: a = argmax_a Q(s,a)
                a = A(best_id);
            end
            ss = SS(s, a);
            r = R(s, a);
            if ss == S(end)
                Q(s, a) = r;                  % terminal successor has no future value
            else
                Q(s, a) = r + gamma(j) * max(Q(ss,:));
            end
            s = ss;
        end
        QQ(:,:,i) = Q;
        epsilon = epsilon_decay * epsilon;
    end
    QQs{j} = QQ;
end

%% PART3: CONVERGENCE VISUALIZATION (Q-Table)
% For every gamma: prints the final Q-table next to R, then plots each
% Q(s,a) entry against the episode index in its own subplot.
% Reads QQs, gamma, maxIterations, nS, nA and R from the earlier parts.
figure(1)
color = {'r.-', 'g.-', 'b.-';...
         'c.-', 'm.-', 'k.-'};   % color{s, a}: line style of the curve for Q(s,a)
ui_row = 2;
ui_col = ceil(length(QQs)/ui_row);
for j = 1:length(QQs)
    % History of this gamma. Plotting the leftover QQ from PART2 instead
    % would draw the last gamma's curves in every subplot.
    QQj = QQs{j};

    % print the final Q-table and R for comparison
    disp('+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-')
    % Bracket concatenation keeps the blank after '='; strcat would strip it.
    disp(['Q when gamma = ', num2str(gamma(j)), ':']);
    disp(QQj(:,:,maxIterations));
    disp('R (Pre-determinated ):');
    disp(R);
    disp('+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-')

    % convergence of every Q(s,a) entry for this gamma
    subplot(ui_row, ui_col, j);
    for s = 1 : nS-1
        for a = 1 : nA
            plot(1:maxIterations, reshape(QQj(s,a,:), 1, maxIterations), ...
                 color{s, a}, 'LineWidth', 2, 'MarkerSize', 10);
            hold on;
        end
    end
    legend('<s_1,a_1>','<s_1,a_2>','<s_1,a_3>','<s_2,a_1>','<s_2,a_2>','<s_2,a_3>');
    title(['gamma = ', num2str(gamma(j))]);
    hold off;
end

posted on 2016-05-13 13:29  sunscone  阅读(547)  评论(0)    收藏  举报

导航