Reinforcement Learning -- Policy Iteration
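The PolicyIteration class below implements policy iteration on a small grid world with a single terminal (absorbing) state at cell [2, 2]. It alternates two steps: policy evaluation, which sweeps every state and backs up its value under the current policy, and policy improvement, which makes the policy greedy with respect to the freshly evaluated values. The evaluation step is the Bellman expectation equation,

    V_{k+1}(s) = \sum_{a} \pi(a \mid s) \left[ R(s, a) + \gamma V_k(s') \right]

where \pi(a \mid s) is the probability the current policy assigns to action a in state s, s' is the state reached by taking a, and \gamma is the discount factor (0.9 here).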
import random


class PolicyIteration:
    def __init__(self, env):
        self.env = env
        # 2-D table for the value function, one entry per grid cell
        self.value_table = [[0.0] * env.width for _ in range(env.height)]
        # policy table, initialized to the uniform random policy
        # (equal probability of up, down, left and right in every state)
        self.policy_table = [[[0.25, 0.25, 0.25, 0.25]
                              for _ in range(env.width)]
                             for _ in range(env.height)]
        # the terminal (absorbing) state has no actions
        self.policy_table[2][2] = []
        self.discount_factor = 0.9

    def policy_evaluation(self):
        # fresh table so the whole sweep is a synchronous backup
        next_value_table = [[0.00] * self.env.width
                            for _ in range(self.env.height)]

        # Bellman expectation equation for every state
        for state in self.env.get_all_states():
            value = 0.0
            # keep the value function of the terminal state at 0
            if state == [2, 2]:
                next_value_table[state[0]][state[1]] = value
                continue

            # expectation over all actions under the current policy
            for action in self.env.possible_actions:
                next_state = self.env.state_after_action(state, action)
                reward = self.env.get_reward(state, action)
                next_value = self.get_value(next_state)
                value += (self.get_policy(state)[action] *
                          (reward + self.discount_factor * next_value))

            next_value_table[state[0]][state[1]] = round(value, 2)

        self.value_table = next_value_table

    def policy_improvement(self):
        # note: this aliases self.policy_table, so the update is in place
        next_policy = self.policy_table
        for state in self.env.get_all_states():
            if state == [2, 2]:
                continue
            value = -99999
            max_index = []
            result = [0.0, 0.0, 0.0, 0.0]  # new policy for this state

            # for every action, compute
            # reward + discount_factor * (next state's value)
            # and keep the action(s) that attain the maximum
            for index, action in enumerate(self.env.possible_actions):
                next_state = self.env.state_after_action(state, action)
                reward = self.env.get_reward(state, action)
                next_value = self.get_value(next_state)
                temp = reward + self.discount_factor * next_value

                # a greedy policy normally picks a single action, but
                # here ties are kept: several actions may share the max
                if temp == value:
                    max_index.append(index)
                elif temp > value:
                    value = temp
                    max_index.clear()
                    max_index.append(index)

            # split the probability mass evenly over the best actions
            prob = 1 / len(max_index)
            for index in max_index:
                result[index] = prob

            next_policy[state[0]][state[1]] = result

        self.policy_table = next_policy

    # sample an action according to the current policy
    def get_action(self, state):
        random_pick = random.randrange(100) / 100

        policy = self.get_policy(state)
        policy_sum = 0.0
        # walk the cumulative distribution and return the sampled index
        for index, value in enumerate(policy):
            policy_sum += value
            if random_pick < policy_sum:
                return index

    # get the policy of a specific state
    def get_policy(self, state):
        if state == [2, 2]:
            return 0.0
        return self.policy_table[state[0]][state[1]]

    def get_value(self, state):
        return round(self.value_table[state[0]][state[1]], 2)
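The class calls into an environment object that this post does not show. For completeness, here is a minimal sketch of the interface those calls assume; the method names and the action-index convention come straight from the code above, while the Env class name, the 5x5 grid size, and the reward placement are assumptions made for illustration:

    # Sketch of the environment interface assumed by PolicyIteration.
    # Grid size, reward layout, and the class name are illustrative only.
    class Env:
        def __init__(self, width=5, height=5):
            self.width = width
            self.height = height
            # action indices: 0 = up, 1 = down, 2 = left, 3 = right
            self.possible_actions = [0, 1, 2, 3]
            self.action_moves = [(-1, 0), (1, 0), (0, -1), (0, 1)]
            # illustrative reward: +1 for entering the terminal cell [2, 2]
            self.rewards = {(2, 2): 1.0}

        def get_all_states(self):
            return [[r, c] for r in range(self.height)
                    for c in range(self.width)]

        def state_after_action(self, state, action):
            dr, dc = self.action_moves[action]
            # clamp to the grid so moves off the edge leave the state unchanged
            row = min(max(state[0] + dr, 0), self.height - 1)
            col = min(max(state[1] + dc, 0), self.width - 1)
            return [row, col]

        def get_reward(self, state, action):
            next_state = self.state_after_action(state, action)
            return self.rewards.get(tuple(next_state), 0.0)

Note that possible_actions must be the integers 0..3, because the same values are used both to look up moves and to index the per-state policy lists.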
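Policy iteration is then the alternation of the two methods until the greedy policy stops changing. The original code does not include a driver loop, so the following is a sketch; the deep copies are needed because policy_improvement updates the policy table in place, and the sweep cap is a safety margin in case the two-decimal rounding keeps the value comparison from settling:

    import copy

    env = Env()                      # the sketch environment above
    agent = PolicyIteration(env)

    while True:
        # evaluate the current policy until the (rounded) values settle
        for _ in range(1000):        # safety cap on evaluation sweeps
            old_values = copy.deepcopy(agent.value_table)
            agent.policy_evaluation()
            if agent.value_table == old_values:
                break

        # then make the policy greedy with respect to those values;
        # copy first, since policy_improvement mutates the table in place
        old_policy = copy.deepcopy(agent.policy_table)
        agent.policy_improvement()
        if agent.policy_table == old_policy:
            break                    # policy stable: policy iteration is done

    print(agent.value_table)         # converged state values
    print(agent.policy_table)        # greedy action probabilities per state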