【Python】强化学习Q-Learning走迷宫

Q-Learning是一种基于值函数的强化学习算法,这里用该算法解决走迷宫问题。

算法步骤如下:

1. 初始化 Q 表:每个表格对应状态动作的 Q 值。这里就是一个H*W*4的表,4代表上下左右四个动作。

2. 选择动作: 根据 Q 表格选择最优动作或者以一定概率随机选择动作。

3. 执行动作,得到返回奖励(这里需要自定义,比如到达目标给的大的reward,撞墙给个小的reward)和下一个状态。

4. 更新 Q 表: 根据规则更新 Q 表格中对应状态动作的 Q 值。规则为 Q(s, a) = Q(s, a) + α*[r + γ*max(Q(s', a')) - Q(s, a)],其中 α 是学习率,γ 是折扣因子,r 是获得的奖励,s 是当前状态,a 是当前动作,s' 是下一个状态,a' 是在下一个状态下选择的最优动作。

5. 重复步骤 2-4: 不断与环境交互,选择动作、执行、更新 Q 值,直至满足停止条件(如达到最大迭代次数或者 Q 值收敛等)。

6. 最优策略提取: 通过学习得到的 Q 表格,可以提取最优策略,即在每个状态下选择具有最高 Q 值的动作。

代码如下:

import numpy as np
import random
import matplotlib.pyplot as plt
from PIL import Image
import imageio
import io

H = 30
W = 40

start = (0, random.randint(0, H-1))
goal = (W-1, random.randint(0, H-1))

img = Image.new('RGB', (W, H), (255, 255, 255))
pixels = img.load()

maze = np.zeros((W, H))
for h in range(H):
    for w in range(W):
        if random.random() < 0.1:
            maze[w, h] = -1

actions_num = 4
actions = [0, 1, 2, 3]
q_table = np.zeros((W, H, actions_num))
rate = 0.5
factor = 0.9
images = []

for i in range(2000):

    state = start
    path = [start]
    while(True):

        if np.random.rand() < 0.1:              #随机或者下一个状态最大q值对应的动作
            action = np.random.choice(actions)
        else:
            action = np.argmax(q_table[state])

        next_state = None                       #执行该动作
        if action == 0 and state[0] > 0:
            next_state = (state[0]-1, state[1])
        elif action == 1 and state[0] < W-1:
            next_state = (state[0]+1, state[1])
        elif action == 2 and state[1] > 0:
            next_state = (state[0], state[1]-1)
        elif action == 3 and state[1] < H-1:
            next_state = (state[0], state[1]+1)
        else:
            next_state = state

        if next_state == goal:                  #得到reward,到目标给大正反馈
            reward = 100
        elif maze[next_state] == -1:
            reward = -100                       #遇见障碍物给大负反馈
        else:
            reward = -1                         #走一步给小负反馈,走的步数越小,负反馈越小

        done = (state == goal)  

        if done:
            break

        current_q = q_table[state][action]      #根据公式更新qtable
        q_table[state][action] += rate * (reward + factor * max(q_table[next_state])  - current_q) 

        state = next_state
        path.append(state)

    if i % 100 == 0:                            #每100次看结果

        for h in range(H):
            for w in range(W):
                if maze[w,h]==-1:
                    pixels[w, h] = (0, 0, 0)
                else:
                    pixels[w, h] = (255, 255, 255)

        for x, y in path:
            pixels[x, y] = (0, 0, 255)

        pixels[start] = (255, 0, 0)
        pixels[goal] = (0, 255, 0)

        plt.clf()                           # 清除当前图形
        plt.imshow(img)
        plt.pause(0.1)                      # 暂停0.1秒,显示动态效果

        buf = io.BytesIO()
        plt.savefig(buf, format='png')      # 保存图像到内存中
        buf.seek(0)                         # 将文件指针移动到文件开头
        images.append(imageio.imread(buf))  # 从内存中读取图像并添加到列表中

plt.show()
imageio.mimsave('result.gif', images, fps=3)  # 保存为 GIF 图像,帧率为3

结果如下:

posted @ 2024-02-15 13:08  Dsp Tian  阅读(72)  评论(0编辑  收藏  举报