增强学习（强化学习）——Sarsa 算法
import random
from collections import defaultdict

import numpy as np


class SARSAgent:
    """Tabular SARSA agent with an epsilon-greedy behaviour policy.

    The agent learns on-policy after every time step from the sample
    ``<s, a, r, s', a'>``: the action ``a'`` used in the update target is
    the same action the agent then actually takes.
    """

    def __init__(self, actions, learning_rate=0.01, discount_factor=0.9,
                 epsilon=0.1):
        """Create an agent.

        Args:
            actions: sequence of action indices. Q-values are stored in a
                per-state list indexed by action, so the entries are expected
                to be ``0 .. len(actions) - 1`` (e.g. ``list(range(n))``).
            learning_rate: step size (alpha) of the SARSA update.
            discount_factor: discount (gamma) applied to Q(s', a').
            epsilon: probability of taking a uniformly random action.
        """
        self.actions = actions
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        n_actions = len(actions)
        # Action-value table Q(s, a) -- the table SARSA updates (unlike
        # Monte Carlo prediction, which updates a state-value table V(s)).
        # Unseen states start at 0.0 for every action. Sized from `actions`
        # instead of a hard-coded 4 so any action count works.
        self.q_table = defaultdict(lambda: [0.0] * n_actions)

    def learn(self, state, action, reward, next_state, next_action):
        """Apply one SARSA update from the sample <s, a, r, s', a'>.

        Q(s, a) <- Q(s, a) + alpha * (r + gamma * Q(s', a') - Q(s, a))

        Args:
            state: key of the state s in which `action` was taken.
            action: index of the action a taken in `state`.
            reward: reward r observed after taking `action`.
            next_state: key of the resulting state s'.
            next_action: index of the action a' chosen in `next_state`.
        """
        current_q = self.q_table[state][action]
        next_state_q = self.q_table[next_state][next_action]
        td_target = reward + self.discount_factor * next_state_q
        self.q_table[state][action] = (
            current_q + self.learning_rate * (td_target - current_q))

    def get_action(self, state):
        """Choose an action for `state` with an epsilon-greedy policy.

        With probability `epsilon` a random action from `self.actions` is
        returned (exploration). Otherwise the action with the highest
        Q-value is returned, with ties broken uniformly at random.
        """
        if np.random.rand() < self.epsilon:
            # Explore: take a random action.
            return np.random.choice(self.actions)
        # Exploit: act greedily with respect to the current Q-table.
        return self.arg_max(self.q_table[state])

    @staticmethod
    def arg_max(state_action):
        """Return an index of the largest value in `state_action`.

        When several entries share the maximum, one of their indices is
        chosen uniformly at random, so the greedy policy does not always
        favour the lowest-numbered action.
        """
        max_value = max(state_action)
        max_indices = [index for index, value in enumerate(state_action)
                       if value == max_value]
        return random.choice(max_indices)


def main(episodes=1000):
    """Train a SARSA agent in the project's `Env` for `episodes` episodes."""
    # Imported here rather than at module level so SARSAgent can be imported
    # (and tested) without the environment module, which only the demo needs.
    from environment import Env

    env = Env()
    agent = SARSAgent(actions=list(range(env.n_actions)))

    for _ in range(episodes):
        # Reset the environment and choose the episode's first action.
        # States are passed to the agent as str(state), presumably because
        # the env returns unhashable states (e.g. lists) -- TODO confirm.
        state = env.reset()
        action = agent.get_action(str(state))

        while True:
            env.render()

            # Take the action and advance one step in the environment.
            next_state, reward, done = env.step(action)
            # On-policy: choose a' now; it is used both in the update target
            # and as the action actually taken on the next step.
            next_action = agent.get_action(str(next_state))

            # Learn from the sample <s, a, r, s', a'>.
            agent.learn(str(state), action, reward, str(next_state),
                        next_action)

            state = next_state
            action = next_action

            # Print the Q-values of all states on screen.
            env.print_value_all(agent.q_table)

            if done:
                break


if __name__ == "__main__":
    main()
桔桔桔桔桔桔桔桔桔桔