A Toy Example of Reinforcement Learning (MATLAB Code)
如下图所示:

假设我们有一个 agent，它有三个状态 S = {s1, s2, s3}（其中 s3 为终止状态，到达即结束一个 episode）和三个动作 A = {a1, a2, a3}，并给定在每个状态下执行不同动作所获得的奖励 R(s,a)。如何进行 Q-Learning？下面是我给出的一个 MATLAB 实现：
%% Toy Q-learning example: 3 states x 3 actions, deterministic transitions.
% States:  s1, s2, s3 -- s3 (= S(end)) is terminal; an episode ends on reaching it.
% Actions: a1, a2, a3.
% For every discount factor in `gamma`, Q-learning (learning rate 1,
% epsilon-greedy exploration) runs for maxIterations episodes. The Q-table
% after each episode is stored in QQs so convergence can be plotted in PART3.

%% PART1: RULE DEFINITION

% Human-readable names for the numeric encoding used below:
% S  = {'s1', 's2', 's3'};
% A  = {'up', 'left', 'right'};
% SS = {'s3', 's1', 's2';...
%       's3', 's1', 's2';...
%       [],   [],   []};

S = [1, 2, 3];   % state set; S(end) is the terminal state
A = [1, 2, 3];   % action set

% Next state s' (namely, the state reached from state s = row by action a = column).
% There is no row for the terminal state s3: no transition leaves it.
SS = [3, 1, 2;...
      3, 1, 2];

% Reward R(s, a) for taking action a (column) in state s (row).
% No row for the terminal state (the original `[], [], []` row was empty anyway).
R = [-0.2, -0.1,  0.2;...
      0.1, -0.2, -0.1];

%% PART2: RL using Q-LEARNING
nS = length(S);   % number of states (including the terminal one)
nA = length(A);   % number of actions

gamma         = 0 : 0.2 : 1;   % discount factors to compare
epsilon0      = 1;             % initial exploration rate of the epsilon-greedy strategy
epsilon_decay = 0.9;           % per-episode decay of epsilon; = 1 keeps exploring at a fixed rate
maxIterations = 100;           % number of learning episodes per gamma
maxSteps      = 1000;          % cap on steps within one episode (see note in the loop)
QQs = cell(1, length(gamma));  % QQs{j}(:, :, i) = Q-table after episode i with gamma = gamma(j)

for j = 1:length(gamma)
    % Restart exploration for every gamma. Without this, every gamma after
    % the first inherits an epsilon already decayed to ~0 and barely explores.
    epsilon = epsilon0;
    % Q-table over the non-terminal states only (the terminal state has no
    % outgoing actions), randomly initialised.
    Q  = rand(nS-1, nA);
    % History of Q after every episode, used for the convergence plots.
    QQ = zeros(nS-1, nA, maxIterations);
    for i = 1:maxIterations
        s = S(randi(nS));   % random start state (terminal start => empty episode)
        nSteps = 0;
        % With gamma = 1 the zero-net-reward cycle s1 -a3-> s2 -a2-> s1 is a
        % fixed point of the update below. Its Q-values can stay above those
        % of the terminating actions, so a near-greedy policy may loop for an
        % astronomically long time. Cap the episode length.
        while s ~= S(end) && nSteps < maxSteps
            if rand() < epsilon
                a = A(randi(nA));              % exploration: random action
            else
                [~, best_id] = max(Q(s, :));   % exploitation: a = argmax_a Q(s, a)
                a = A(best_id);
            end
            ss = SS(s, a);
            r  = R(s, a);
            % Deterministic environment, learning rate 1:
            %   Q(s, a) <- r + gamma * max_a' Q(s', a')
            % with no future value once the terminal state is reached.
            if ss == S(end)
                Q(s, a) = r;
            else
                Q(s, a) = r + gamma(j) * max(Q(ss, :));
            end
            s = ss;
            nSteps = nSteps + 1;
        end
        QQ(:, :, i) = Q;
        epsilon = epsilon_decay * epsilon;
    end
    QQs{j} = QQ;
end

%% PART3: CONVERGENCE VISUALIZATION (Q-Table)
figure(1)
% One line style per <state, action> pair (rows: s1, s2; columns: a1..a3).
color = {'r.-', 'g.-', 'b.-';...
         'c.-', 'm.-', 'k.-'};
ui_row = 2;
ui_col = ceil(length(QQs)/ui_row);
for j = 1:length(QQs)
    % Print the final Q-table next to R to compare them (check QQ >= R).
    % Use [] concatenation: strcat would strip the trailing space of 'Q when gamma = '.
    disp('+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-')
    disp(['Q when gamma = ', num2str(gamma(j)), ':']);
    disp(QQs{j}(:, :, maxIterations));
    disp('R (Pre-determinated ):');
    disp(R);
    disp('+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-')
    % Convergence of every Q(s, a) over the episodes for this gamma.
    subplot(ui_row, ui_col, j);
    for s = 1:nS-1
        for a = 1:nA
            % Plot this gamma's history QQs{j}, not the QQ left over from the
            % last gamma of PART2.
            plot(1:maxIterations, reshape(QQs{j}(s, a, :), 1, maxIterations), ...
                 color{s, a}, 'LineWidth', 2, 'MarkerSize', 10);
            hold on;
        end
    end
    legend('<s_1,a_1>','<s_1,a_2>','<s_1,a_3>','<s_2,a_1>','<s_2,a_2>','<s_2,a_3>');
    title(['gamma = ', num2str(gamma(j))]);
    hold off;
end
浙公网安备 33010602011771号