Q-learning

Q-learning (Notebook)

Environment
!apt-get update
!apt install -y python3.9

!pip install virtualenv

%cd /kaggle/working
!virtualenv venv -p $(which python3.9)
# !virtualenv myenv

!python3.9 --version

!pip install gym==0.26.2
!pip install pettingzoo==1.23.1

# import torch
# torch.__version__
# '2.4.1+cu121'
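
As a quick sanity check (a minimal sketch; it assumes the two packages above were installed into the active kernel rather than the virtualenv), you can confirm the versions before moving on:

import gym
import pettingzoo

print(gym.__version__)         # expected 0.26.2
print(pettingzoo.__version__)  # expected 1.23.1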

Code

1.MyWrapper


import gym
import matplotlib.pyplot as plt


# Define the environment
class MyWrapper(gym.Wrapper):
    def __init__(self):
        # is_slippery controls whether the ice is slippery
        # gym.make() creates the environment; here the environment is the game process
        env = gym.make('FrozenLake-v1',
                       render_mode='rgb_array',
                       is_slippery=False)

        super().__init__(env)
        self.env = env
        self.counter = 0

    def reset(self):
        # Reset the environment, send the agent back to the start, and return the initial state
        state, _ = self.env.reset()
        return state

    def step(self, action):
        # The agent actually executes the action; the environment then updates its state and returns a reward
        state, reward, terminated, truncated, info = self.env.step(action)
        over = terminated or truncated

        # Lose one point per step, pushing the agent to finish the game quickly
        if not over:
            reward = -1

        # Falling into a hole costs 100 points
        if over and reward == 0:
            reward = -100

        return state, reward, over

    # Render and save the game image
    def show(self):
        plt.figure(figsize=(3, 3))
        plt.imshow(self.render())
        plt.savefig(f'/kaggle/working/step_{self.counter}.png')
        plt.show()
        self.counter += 1
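
play() below relies on a module-level env, so the notebook presumably instantiates the wrapper right after this cell. A minimal smoke test (only the reset/step/show calls defined above):

env = MyWrapper()

state = env.reset()
print('initial state:', state)

state, reward, over = env.step(env.action_space.sample())
print('after one random step:', state, reward, over)

env.show()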


2.Q


import numpy as np

# Initialize the Q table, which defines the value of every action in every state
# The game board has 4x4 = 16 states, and in each cell the agent can take one of 4 actions (up, down, left, right)
Q = np.zeros((16, 4))
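
Once training has filled the table in, it can be read directly; a small sketch (not part of the original notebook) that prints the greedy action and the best value for each of the 16 grid cells:

# Greedy action per cell, arranged as the 4x4 board
print(Q.argmax(axis=1).reshape(4, 4))

# Value of the best action per cell
print(Q.max(axis=1).reshape(4, 4).round(2))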

3.Play


import random
import time

from IPython import display


# Play one episode and record the data
def play(show=False):
    data = []
    reward_sum = 0

    state = env.reset()
    over = False

    while not over:
        # Given the current state, look up the Q table and take the action with the highest score
        action = Q[state].argmax()
        # Add some randomness so the agent is not too rigid: with 10% probability take a random action
        if random.random() < 0.1:
            # Here an action is sampled uniformly; in practice the action should come from a policy conditioned on the state
            action = env.action_space.sample()

        next_state, reward, over = env.step(action)

        data.append((state, action, reward, next_state, over))
        reward_sum += reward

        state = next_state

        if show:
            time.sleep(1)
            display.clear_output(wait=True)
            env.show()

    return data, reward_sum
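
A quick look at what one episode produces (a small sketch; exact numbers vary because of the 10% random actions):

data, reward_sum = play()
print('transitions collected:', len(data))
print('episode return:', reward_sum)
print('first transition:', data[0])  # (state, action, reward, next_state, over)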


4.Pool


# Replay pool
class Pool:

    def __init__(self):
        self.pool = []

    def __len__(self):
        return len(self.pool)

    def __getitem__(self, i):
        return self.pool[i]

    # Refresh the pool
    def update(self):
        # Add at least N new transitions on every refresh
        old_len = len(self.pool)
        while len(self.pool) - old_len < 200:
            self.pool.extend(play()[0])

        # Keep only the newest N transitions
        self.pool = self.pool[-10_000:]

    # Sample one transition from the pool
    def sample(self):
        return random.choice(self.pool)
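
train() below reads from a module-level pool, so it presumably gets created right after this cell; a minimal warm-up sketch:

pool = Pool()

# Fill the pool with an initial batch of transitions and peek at one of them
pool.update()
print(len(pool), pool.sample())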


5.Train


# Training loop
def train():
    # Refresh the data N times in total
    for epoch in range(1000):
        pool.update()

        # After each data refresh, run N update steps
        for i in range(200):

            # Sample one random transition
            state, action, reward, next_state, over = pool.sample()

            # The Q table's current estimate of the value of taking `action` in `state`
            value = Q[state, action]

            # The reward actually obtained from playing, plus 0.9 times the value of the next state
            target = reward + Q[next_state].max() * 0.9

            # value and target should be equal if the Q table's estimate is accurate
            # If they differ, the Q table should be nudged towards target to correct its bias
            # This gap is the TD error: the discrepancy between estimates, corrected towards the estimate that contains more real experience
            update = (target - value) * 0.1

            # Update the Q table
            Q[state, action] += update

        if epoch % 100 == 0:
            print(epoch, len(pool), play()[-1])
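
Finally, a hedged sketch of driving the notebook end to end: train the Q table, then watch the (mostly) greedy policy play one rendered episode; the value printed at the end is that episode's return.

train()

# Watch the trained agent play one episode with rendering
play(show=True)[-1]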

