Code reference: Prioritized Experience Replay (DQN)

# -*- coding: utf-8 -*-
import os
import random
import time as t

import gym
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam

# limit this process to 20% of the GPU memory
config = tf.compat.v1.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.2
tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

EPISODES = 200  # number of episodes the agent plays


class SumTree(object):
    """
    This SumTree code is a modified version and the original code is from:
    https://github.com/jaara/AI-blog/blob/master/SumTree.py
    Store data with its priority in the tree.
    """
    data_pointer = 0

    def __init__(self, capacity):
        self.capacity = capacity  # for all priority values
        self.tree = np.zeros(2 * capacity - 1)
        # [--------------Parent nodes-------------][-------leaves to record priority-------]
        #             size: capacity - 1                       size: capacity
        self.data = np.zeros(capacity, dtype=object)  # for all transitions
        # [--------------data frame-------------]
        #             size: capacity

    def add(self, p, data):
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data  # update data_frame
        self.update(tree_idx, p)  # update tree_frame

        self.data_pointer += 1
        if self.data_pointer >= self.capacity:  # replace when exceed the capacity
            self.data_pointer = 0

    def update(self, tree_idx, p):
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        # then propagate the change through the tree
        while tree_idx != 0:  # this method is faster than the recursive loop in the reference code
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, v):
        """
        Tree structure and array storage:
        Tree index:
             0         -> storing priority sum
            / \
          1     2
         / \   / \
        3   4 5   6    -> storing priority for transitions
        Array type for storing:
        [0,1,2,3,4,5,6]
        """
        parent_idx = 0
        while True:  # the while loop is faster than the method in the reference code
            cl_idx = 2 * parent_idx + 1  # this leaf's left and right kids
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):  # reach bottom, end search
                leaf_idx = parent_idx
                break
            else:  # downward search, always search for a higher priority node
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx

        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_p(self):
        return self.tree[0]  # the root

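# The demo below is not part of the original script; it is a minimal sketch
# showing how the SumTree above turns a uniform draw on [0, total_p) into
# priority-proportional sampling (the function name and values are illustrative).
def _sumtree_sampling_demo():
    tree = SumTree(capacity=4)
    for p in (1.0, 2.0, 3.0, 4.0):
        tree.add(p, data=('transition with priority', p))
    # the root stores the sum of all leaf priorities: 1 + 2 + 3 + 4 = 10
    assert tree.total_p == 10.0
    # a leaf is returned with probability proportional to its priority,
    # e.g. the p=4.0 leaf is hit for any v in (6.0, 10.0], i.e. 40% of draws
    leaf_idx, priority, data = tree.get_leaf(np.random.uniform(0.0, tree.total_p))
    return leaf_idx, priority, data
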
class Memory(object):  # stored as ( s, a, r, s_ ) in SumTree
    """
    This Memory class is modified based on the original code from:
    https://github.com/jaara/AI-blog/blob/master/Seaquest-DDQN-PER.py
    """
    epsilon = 0.01  # small amount to avoid zero priority
    alpha = 0.6  # [0~1] convert the importance of TD error to priority
    beta = 0.4  # importance-sampling, from initial value increasing to 1
    beta_increment_per_sampling = 0.001
    abs_err_upper = 1.  # clipped abs error

    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        self.size = 0  # number of transitions currently stored (capped at capacity)

    def store(self, transition):
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        if max_p == 0:
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)  # new transitions get the current max priority
        self.size = min(self.size + 1, self.tree.capacity)

    def sample(self, n):
        b_idx = np.empty((n,), dtype=np.int32)
        b_memory = [None] * n
        ISWeights = np.empty((n, 1))
        pri_seg = self.tree.total_p / n  # priority segment
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])  # anneal beta towards 1

        # smallest non-zero leaf probability, used to normalize the IS weights
        leaf_p = self.tree.tree[-self.tree.capacity:]
        min_prob = np.min(leaf_p[leaf_p > 0]) / self.tree.total_p
        for i in range(n):
            a, b = pri_seg * i, pri_seg * (i + 1)
            v = np.random.uniform(a, b)
            idx, p, data = self.tree.get_leaf(v)
            prob = p / self.tree.total_p
            ISWeights[i, 0] = np.power(prob / min_prob, -self.beta)
            b_idx[i], b_memory[i] = idx, data
        return b_idx, b_memory, ISWeights

    def batch_update(self, tree_idx, abs_errors):
        abs_errors = np.abs(np.asarray(abs_errors)) + self.epsilon  # convert to abs and avoid 0
        clipped_errors = np.minimum(abs_errors, self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)

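# The helper below is not part of the original script; it is a small worked
# example (illustrative numbers only) of the two formulas the Memory class uses:
# priorities p_i = (|TD error| + epsilon) ** alpha (see batch_update) and
# normalized importance-sampling weights w_i = (P(i) / min_j P(j)) ** (-beta)
# (see sample), where P(i) = p_i / sum_j p_j.
def _per_weight_demo(abs_td_errors=(0.05, 0.5, 2.0), epsilon=0.01, alpha=0.6, beta=0.4):
    p = np.power(np.minimum(np.asarray(abs_td_errors) + epsilon, 1.0), alpha)  # clipped like abs_err_upper
    probs = p / p.sum()  # sampling probability of each transition
    is_weights = np.power(probs / probs.min(), -beta)  # largest-priority sample gets the smallest weight
    return probs, is_weights
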
class DQNAgent:
    def __init__(self, state_shape, action_size, batch_size):
        # IS weights live in a tf.Variable so that my_loss always sees the values
        # assigned in replay(); a plain numpy attribute would be frozen into the
        # graph when the loss function is first traced.
        self.ISWeights = tf.Variable(np.ones((batch_size, 1), dtype=np.float32), trainable=False)
        self.batch_size = batch_size
        self.state_shape = state_shape
        self.action_size = action_size
        self.gamma = 0.95  # discount rate for future rewards
        self.epsilon = 0.5  # initial exploration rate for action selection
        self.epsilon_min = 0.01  # lower bound on the exploration rate
        self.epsilon_decay = 0.995  # decay the exploration rate as the agent improves
        self.learning_rate = 0.001
        self.replace_target_iter = 300  # learning steps between target_model updates
        self.learn_step_counter = 0  # total learning steps
        self.memory = Memory(capacity=10000)
        self.model = self._build_model()
        self.target_model = self._build_model()

    def my_loss(self, y_true, y_pred):
        # importance-sampling weighted MSE
        return tf.math.reduce_mean(self.ISWeights * tf.math.squared_difference(y_pred, y_true))

    def _build_model(self):
        inputs = Input(name='inputs', shape=self.state_shape)
        x = Dense(24, activation='relu')(inputs)
        x = Dense(24, activation='relu')(x)
        y_pred = Dense(self.action_size, name="predictions")(x)
        # y_pred = Activation('softmax', name='softmax')(x)

        model = keras.Model(inputs=inputs, outputs=y_pred)
        model.compile(loss=self.my_loss,
                      optimizer=Adam(learning_rate=self.learning_rate))
        model.summary()
        return model

    def remember(self, state, action, reward, next_state):
        transition = (state, action, reward, next_state)
        self.memory.store(transition)

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        act_values = self.model.predict(np.expand_dims(state, axis=0))
        return np.argmax(act_values[0])

    def replay(self, batch_size):
        # periodically copy the online weights into the target network
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.target_model.set_weights(self.model.get_weights())
        self.learn_step_counter += 1

        tree_idx, minibatch, is_weights = self.memory.sample(batch_size)
        self.ISWeights.assign(is_weights.astype(np.float32))
        states, q_targets, abs_errors = [], [], []
        for state, action, reward, next_state in minibatch:
            states.append(state)
            # Double DQN: the online network selects the next action ...
            q_eval_next = self.model.predict(np.expand_dims(next_state, axis=0))[0]
            max_action_next = np.argmax(q_eval_next)
            # ... and the target network evaluates it
            q_next = self.target_model.predict(np.expand_dims(next_state, axis=0))[0]
            target = reward + self.gamma * q_next[max_action_next]
            q_target = self.model.predict(np.expand_dims(state, axis=0))[0]
            abs_errors.append(np.abs(target - q_target[action]))  # TD error for the new priority
            q_target[action] = target
            q_targets.append(q_target)
        self.memory.batch_update(tree_idx, abs_errors)  # update priorities in the SumTree
        self.model.fit(np.asarray(states, dtype=np.float32),
                       np.asarray(q_targets, dtype=np.float32),
                       batch_size=batch_size, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

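# Not part of the original script: a tiny numeric sketch of the Double DQN
# target computed inside replay() above. The online network only *selects*
# the next action; the target network *evaluates* it (all values are made up).
def _double_dqn_target_demo(reward=1.0, gamma=0.95):
    q_eval_next = np.array([0.2, 0.9])  # online net Q(s', .): action 1 looks best
    q_next = np.array([0.5, 0.7])       # target net Q(s', .)
    max_action_next = np.argmax(q_eval_next)          # -> 1
    return reward + gamma * q_next[max_action_next]   # 1.0 + 0.95 * 0.7 = 1.665
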
if __name__ == "__main__":
    # initialize the gym environment and the agent
    # Breakout-v4
    batch_size = 32
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size, batch_size)
    # agent.model.load_weights('model_weights.h5')
    done = False

    # iterate over episodes
    for e in range(EPISODES):
        start = t.time()
        # reset the environment state at the start of every episode
        state = env.reset()

        # `time` counts the frames of one episode; the score increases by 1
        # for every step the pole stays balanced, up to 500 in CartPole-v1,
        # so higher is better
        for time in range(500):
            # at every frame the agent chooses an action from the current state
            action = agent.act(state)
            # the action moves the game to next_state and yields a reward:
            # +1 while the pole stays balanced, -10 when the episode ends
            # env.render()
            next_state, reward, done, _ = env.step(action)
            reward = reward if not done else -10

            # remember the transition: state, action, reward, next_state
            agent.remember(state, action, reward, next_state)

            # move on to the next state
            state = next_state

            # if the pole fell over, the episode is over: print the score
            if done:
                print("episode: {}/{}, score: {}, e: {:.2}, time: {}"
                      .format(e, EPISODES, time, agent.epsilon, t.time() - start))
                agent.model.save_weights('model_weights.h5')
                # env.close()
                break

            # train the agent with previously collected experience
            if agent.memory.size > batch_size:
                agent.replay(batch_size)
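# Not part of the original script: a minimal sketch of how the saved weights
# could be evaluated greedily after training, assuming the 'model_weights.h5'
# file produced by the loop above exists on disk.
def _evaluate_greedy(episodes=5):
    env = gym.make('CartPole-v1')
    agent = DQNAgent(env.observation_space.shape, env.action_space.n, batch_size=32)
    agent.model.load_weights('model_weights.h5')
    agent.epsilon = 0.0  # act greedily, no exploration
    for e in range(episodes):
        state, score = env.reset(), 0
        for _ in range(500):
            state, reward, done, _ = env.step(agent.act(state))
            score += reward
            if done:
                break
        print("evaluation episode {}: score {}".format(e, score))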