Code reference: Prioritized Experience Replay (DQN)
# -*- coding: utf-8 -*-
import os
import random
import time as t

import gym
import keras
import numpy as np
import tensorflow as tf
# cap per-process GPU memory at 20%
import tensorflow.python.keras.backend as backend
from keras.layers import Dense, Input
from keras.optimizers import Adam

config = tf.compat.v1.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.2
backend.set_session(tf.compat.v1.Session(config=config))
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

EPISODES = 200  # number of episodes the agent plays


class SumTree(object):
    """
    This SumTree code is a modified version and the original code is from:
    https://github.com/jaara/AI-blog/blob/master/SumTree.py
    Store data with its priority in the tree.
    """
    data_pointer = 0

    def __init__(self, capacity):
        self.capacity = capacity  # number of leaf nodes (priority values)
        self.tree = np.zeros(2 * capacity - 1)
        # [------------parent nodes------------][-----leaves to record priority-----]
        #           size: capacity - 1                     size: capacity
        self.data = np.zeros(capacity, dtype=object)  # for all transitions
        # [--------------data frame-------------]
        #             size: capacity

    def add(self, p, data):
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data  # update data frame
        self.update(tree_idx, p)             # update tree frame

        self.data_pointer += 1
        if self.data_pointer >= self.capacity:  # overwrite the oldest entry when full
            self.data_pointer = 0

    def update(self, tree_idx, p):
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        # then propagate the change up through the tree
        while tree_idx != 0:  # this loop is faster than the recursion in the reference code
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, v):
        """
        Tree structure and array storage:
        Tree index:
             0         -> storing priority sum
            / \
           1   2
          / \ / \
         3  4 5  6     -> storing priority for transitions
        Array type for storing:
        [0, 1, 2, 3, 4, 5, 6]
        """
        parent_idx = 0
        while True:  # the while loop is faster than the method in the reference code
            cl_idx = 2 * parent_idx + 1  # this node's left and right children
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):  # reached the bottom, end the search
                leaf_idx = parent_idx
                break
            else:  # downward search, always towards the higher-priority side
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx

        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_p(self):
        return self.tree[0]  # the root


class Memory(object):  # stored as (s, a, r, s_) in a SumTree
    """
    This Memory class is modified based on the original code from:
    https://github.com/jaara/AI-blog/blob/master/Seaquest-DDQN-PER.py
    """
    epsilon = 0.01  # small amount to avoid zero priority
    alpha = 0.6     # [0~1] converts the importance of the TD error into priority
    beta = 0.4      # importance sampling, annealed from this initial value towards 1
    beta_increment_per_sampling = 0.001
    abs_err_upper = 1.  # clipped abs error

    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    def store(self, transition):
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        if max_p == 0:
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)  # new transitions get the current maximum priority

    def sample(self, n):
        b_idx, b_memory, ISWeights = np.empty((n,), dtype=np.int32), [None] * n, np.empty((n, 1))
        pri_seg = self.tree.total_p / n  # priority segment
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])  # max = 1

        leaf_p = self.tree.tree[-self.tree.capacity:]
        min_prob = np.min(leaf_p[leaf_p > 0]) / self.tree.total_p  # smallest sampling probability, for the IS weights
        for i in range(n):
            a, b = pri_seg * i, pri_seg * (i + 1)
            v = np.random.uniform(a, b)
            idx, p, data = self.tree.get_leaf(v)
            prob = p / self.tree.total_p
            ISWeights[i, 0] = np.power(prob / min_prob, -self.beta) if min_prob else 0
            b_idx[i], b_memory[i] = idx, data
        return b_idx, b_memory, ISWeights

    def batch_update(self, tree_idx, abs_errors):
        abs_errors = np.abs(abs_errors) + self.epsilon  # take the absolute value and avoid zero priority
        clipped_errors = np.minimum(abs_errors, self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)


class DQNAgent:
    def __init__(self, state_shape, action_size, batch_size):
        # IS weights of the current batch; a non-trainable variable so the values
        # assigned in replay() are picked up by the compiled custom loss
        self.ISWeights = tf.Variable(np.ones((batch_size, 1), dtype=np.float32), trainable=False)
        self.batch_size = batch_size
        self.state_shape = state_shape
        self.action_size = action_size
        self.gamma = 0.95           # discount rate applied to future rewards
        self.epsilon = 0.5          # initial exploration rate when choosing actions
        self.epsilon_min = 0.01     # lower bound on the exploration rate
        self.epsilon_decay = 0.995  # decay the exploration rate as the agent improves
        self.learning_rate = 0.001
        self.replace_target_iter = 300  # learning steps between target_model weight updates
        self.learn_step_counter = 0     # total learning steps
        self.memory = Memory(capacity=10000)
        self.model = self._build_model()
        self.target_model = self._build_model()

    def my_loss(self, y_true, y_pred):
        # importance-sampling weighted mean squared error
        return tf.math.reduce_mean(self.ISWeights * tf.math.squared_difference(y_pred, y_true))

    def _build_model(self):
        inputs = Input(name='inputs', shape=self.state_shape)
        x = Dense(24, activation='relu')(inputs)
        x = Dense(24, activation='relu')(x)
        y_pred = Dense(self.action_size, name="predictions")(x)
        # y_pred = Activation('softmax', name='softmax')(x)

        model = keras.Model(inputs=inputs, outputs=y_pred)
        model.compile(loss=self.my_loss,
                      optimizer=Adam(lr=self.learning_rate))
        model.summary()
        return model

    def remember(self, state, action, reward, next_state):
        transition = (state, action, reward, next_state)
        self.memory.store(transition)

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        act_values = self.model.predict(np.expand_dims(state, axis=0))
        return np.argmax(act_values[0])

    def replay(self, batch_size):
        tree_idx, minibatch, is_weights = self.memory.sample(batch_size)
        self.ISWeights.assign(is_weights.astype(np.float32))  # feed the IS weights to the custom loss
        # periodically copy the online weights into target_model
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.target_model.set_weights(self.model.get_weights())
        states = []
        abs_errors = []
        q_targets = []
        for state, action, reward, next_state in minibatch:
            states.append(state)
            # Double DQN
            q_eval_next = self.model.predict(np.expand_dims(next_state, axis=0))[0]
            q_next = self.target_model.predict(np.expand_dims(next_state, axis=0))[0]
            # action with the highest value according to the online network (q_eval)
            max_action_next = np.argmax(q_eval_next)
            # Double DQN: evaluate, with q_next, the action selected by q_eval
            target = reward + self.gamma * q_next[max_action_next]
            q_target = self.model.predict(np.expand_dims(state, axis=0))[0]
            abs_error = np.abs(target - q_target[action])
            abs_errors.append(abs_error)
            q_target[action] = target
            q_targets.append(q_target)
        self.memory.batch_update(tree_idx, np.asarray(abs_errors))  # update the priorities
        self.model.fit(np.asarray(states), np.asarray(q_targets), epochs=1, verbose=0)
        self.learn_step_counter += 1
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay


if __name__ == "__main__":
    # initialize the gym environment and the agent
    # Breakout-v4
    batch_size = 32
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size, batch_size)
    # agent.model.load_weights('model_weights.h5')
    done = False

    # start iterating over episodes
    for e in range(EPISODES):
        start = t.time()
        # reset the state at the start of every episode
        state = env.reset()

        # `time` indexes the frames of the game;
        # every frame the pole stays balanced adds 1 point, up to 500,
        # and the goal is to score as high as possible
        for time in range(500):
            # each frame, the agent picks an action from the current state
            action = agent.act(state)
            # the action moves the game to next_state and yields a reward:
            # 1 while the pole stays balanced, -10 when the episode ends
            # env.render()
            next_state, reward, done, _ = env.step(action)
            reward = reward if not done else -10

            # remember the transition: state, action, reward, next_state
            agent.remember(state, action, reward, next_state)

            # move to the next state
            state = next_state

            # if the pole falls the episode ends; print the score
            if done:
                print("episode: {}/{}, score: {}, e: {:.2}, time: {}"
                      .format(e, EPISODES, time, agent.epsilon, t.time() - start))
                agent.model.save_weights('model_weights.h5')
                # env.close()
                break

            # train the agent on previously stored experience
            if agent.memory.tree.data_pointer > batch_size:
                agent.replay(batch_size)
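To see why the stratified sampling in Memory.sample favours high-priority transitions, the short sketch below exercises SumTree on its own. It is a minimal check, not part of the listing: it assumes the code above has been saved as per_dqn.py (a hypothetical file name) so that SumTree can be imported, and the priorities 1 to 4 are arbitrary example values.

import numpy as np
from collections import Counter

from per_dqn import SumTree  # hypothetical module name for the listing above

tree = SumTree(capacity=4)
for p in (1.0, 2.0, 3.0, 4.0):              # example priorities
    tree.add(p, 'transition_p%d' % int(p))  # the payload is just a label here

counts = Counter()
for _ in range(10000):
    v = np.random.uniform(0, tree.total_p)  # total_p == 10.0 here
    _, priority, data = tree.get_leaf(v)
    counts[data] += 1

# each transition should be drawn roughly in proportion to its priority,
# i.e. about 10% / 20% / 30% / 40% of the draws
for data, n in sorted(counts.items()):
    print(data, n / 10000.0)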
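The ISWeights returned by Memory.sample use the normalisation w_i = (P(i) / min_j P(j)) ** (-beta): the least likely transition always gets weight 1, while frequently sampled high-priority transitions are down-weighted, which corrects the bias that prioritized sampling introduces into the gradient. A self-contained numpy sketch of that formula, with made-up priorities and the same initial beta = 0.4 as the Memory class:

import numpy as np

priorities = np.array([0.1, 0.5, 1.0, 2.4])  # made-up clipped |TD error| ** alpha values
beta = 0.4                                   # initial beta, annealed towards 1 during training

probs = priorities / priorities.sum()        # P(i): proportional sampling probability
min_prob = probs.min()
is_weights = np.power(probs / min_prob, -beta)

print(probs)       # sampling probabilities: 0.025, 0.125, 0.25, 0.6
print(is_weights)  # 1.0 for the rarest transition, < 1.0 for the others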