MATLAB Simulation of Temporal Difference Learning for Path Planning
Temporal Difference (TD) learning is an important family of reinforcement learning methods that combines the strengths of dynamic programming and Monte Carlo methods. In path planning, TD learning allows an agent to learn an optimal path purely by interacting with its environment.
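At the heart of TD learning is the TD(0) update: after every transition, the current value estimate is nudged toward a bootstrapped target formed from the observed reward and the estimated value of the next state,

$$ V(s_t) \leftarrow V(s_t) + \alpha \left[ r_{t+1} + \gamma V(s_{t+1}) - V(s_t) \right], $$

where $\alpha$ is the learning rate, $\gamma$ is the discount factor, and the bracketed term is the TD error. Q-learning and SARSA, implemented below, apply the same idea to the state-action value function $Q(s, a)$.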
1. Environment setup
First, we create a grid-world environment containing a start cell, a goal cell, obstacles, and rewards.
function gridWorld = createGridWorld(width, height, start, goal, obstacles)
% Create the grid-world environment
% Inputs:
%   width     - grid width
%   height    - grid height
%   start     - start coordinates [x, y]
%   goal      - goal coordinates [x, y]
%   obstacles - obstacle coordinates [x1, y1; x2, y2; ...]
gridWorld.width = width;
gridWorld.height = height;
gridWorld.start = start;
gridWorld.goal = goal;
gridWorld.obstacles = obstacles;
% Define actions: Up(1), Right(2), Down(3), Left(4)
% Note: 'Up' decreases y in this convention, so on a default y-up plot axis the 'Up' arrow points downward
gridWorld.actions = [0, -1; 1, 0; 0, 1; -1, 0]; % [dx, dy]
gridWorld.actionNames = {'Up', 'Right', 'Down', 'Left'};
% Define rewards
gridWorld.rewardGoal = 10;      % reward for reaching the goal
gridWorld.rewardObstacle = -10; % penalty for hitting an obstacle
gridWorld.rewardStep = -1;      % per-step penalty
end
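As a quick illustration of the call (the sizes and coordinates here are arbitrary, not the ones used later in main), the environment can be constructed like this:
% Hypothetical 5x5 world with a single obstacle at (3, 3)
gw = createGridWorld(5, 5, [1, 1], [5, 5], [3, 3]);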
2. Q-learning implementation
Q-learning is a widely used temporal difference algorithm that learns the state-action value function in order to find an optimal policy.
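Concretely, after observing a transition $(s, a, r, s')$, Q-learning applies the off-policy update

$$ Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right], $$

which is exactly the update step implemented in the training loop below (with the bootstrap term dropped on terminal transitions).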
function [Q, policy, stats] = qLearning(gridWorld, params)
% Q-learning algorithm
% Inputs:
%   gridWorld - grid-world environment
%   params    - algorithm parameters
% Outputs:
%   Q      - state-action value function
%   policy - greedy policy extracted from Q
%   stats  - training statistics

% Initialize the Q table
Q = zeros(gridWorld.height, gridWorld.width, 4);

% Initialize statistics
stats.episodeLengths = zeros(params.numEpisodes, 1);
stats.episodeRewards = zeros(params.numEpisodes, 1);
stats.episodeSuccess = zeros(params.numEpisodes, 1);

% Training loop
for episode = 1:params.numEpisodes
    % Initialize the state
    state = gridWorld.start;
    done = false;
    step = 0;
    totalReward = 0;
    % Exploration-rate decay for the epsilon-greedy policy
    epsilon = params.epsilonMin + (params.epsilonMax - params.epsilonMin) * ...
        exp(-params.epsilonDecay * episode);
    % Single-episode loop
    while ~done && step < params.maxSteps
        % Select an action (epsilon-greedy policy)
        if rand() < epsilon
            action = randi(4); % explore
        else
            [~, action] = max(Q(state(2), state(1), :)); % exploit
        end
        % Take the action; observe the next state and reward
        [nextState, reward, done] = stepEnvironment(gridWorld, state, action);
        % Q-learning update (do not bootstrap from terminal transitions)
        currentQ = Q(state(2), state(1), action);
        if done
            target = reward;
        else
            target = reward + params.gamma * max(Q(nextState(2), nextState(1), :));
        end
        Q(state(2), state(1), action) = currentQ + params.alpha * (target - currentQ);
        % Update the state and statistics
        state = nextState;
        step = step + 1;
        totalReward = totalReward + reward;
    end
    % Record statistics
    stats.episodeLengths(episode) = step;
    stats.episodeRewards(episode) = totalReward;
    stats.episodeSuccess(episode) = done && isequal(state, gridWorld.goal);
    % Print progress every 100 episodes
    if mod(episode, 100) == 0
        fprintf('Episode %d/%d, Steps: %d, Reward: %.2f, Success: %d\n', ...
            episode, params.numEpisodes, step, totalReward, stats.episodeSuccess(episode));
    end
end

% Extract the greedy policy from the Q table
policy = extractPolicy(Q);
end
function [nextState, reward, done] = stepEnvironment(gridWorld, state, action)
% Environment transition function
% Inputs:
%   gridWorld - grid-world environment
%   state     - current state [x, y]
%   action    - action to take (1-4)
% Outputs:
%   nextState - next state
%   reward    - reward received
%   done      - whether the episode terminates

% Compute the next state
move = gridWorld.actions(action, :);
nextState = state + move;

% Check the grid boundaries
if nextState(1) < 1 || nextState(1) > gridWorld.width || ...
        nextState(2) < 1 || nextState(2) > gridWorld.height
    nextState = state; % stay in place
end

% Check whether the goal has been reached
if isequal(nextState, gridWorld.goal)
    reward = gridWorld.rewardGoal;
    done = true;
    return;
end

% Check for a collision with an obstacle
isObstacleHit = false;
for i = 1:size(gridWorld.obstacles, 1)
    if isequal(nextState, gridWorld.obstacles(i, :))
        isObstacleHit = true;
        break;
    end
end
if isObstacleHit
    reward = gridWorld.rewardObstacle;
    done = true;
    nextState = state; % hit an obstacle; stay in place
    return;
end

% Ordinary move
reward = gridWorld.rewardStep;
done = false;
end
function policy = extractPolicy(Q)
% Extract the greedy policy from the Q table
[height, width, ~] = size(Q);
policy = zeros(height, width);
for y = 1:height
    for x = 1:width
        [~, policy(y, x)] = max(Q(y, x, :));
    end
end
end
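Note that the Q table is indexed as Q(y, x, action), with the row index first. As a quick sanity check (using hypothetical indices), the greedy action at a cell can be read either from the extracted policy or directly from the Q table:
% Hypothetical cell at x = 2, y = 3
[~, bestAction] = max(Q(3, 2, :));    % greedy action read straight from the Q table
assert(bestAction == policy(3, 2));   % agrees with the extracted policy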
3. SARSA implementation
SARSA is another temporal difference algorithm; its main difference from Q-learning is that the update is on-policy, using the action actually selected in the next state.
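After observing $(s, a, r, s')$ and selecting the next action $a'$ with the same $\varepsilon$-greedy behavior policy, SARSA applies the on-policy update

$$ Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma Q(s', a') - Q(s, a) \right], $$

i.e., it bootstraps from the action actually taken rather than from the greedy maximum. This is the only change relative to the Q-learning code above.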
function [Q, policy, stats] = sarsa(gridWorld, params)
% SARSA algorithm
% Inputs:
%   gridWorld - grid-world environment
%   params    - algorithm parameters
% Outputs:
%   Q      - state-action value function
%   policy - greedy policy extracted from Q
%   stats  - training statistics

% Initialize the Q table
Q = zeros(gridWorld.height, gridWorld.width, 4);

% Initialize statistics
stats.episodeLengths = zeros(params.numEpisodes, 1);
stats.episodeRewards = zeros(params.numEpisodes, 1);
stats.episodeSuccess = zeros(params.numEpisodes, 1);

% Training loop
for episode = 1:params.numEpisodes
    % Initialize the state
    state = gridWorld.start;
    done = false;
    step = 0;
    totalReward = 0;
    % Exploration-rate decay for the epsilon-greedy policy
    epsilon = params.epsilonMin + (params.epsilonMax - params.epsilonMin) * ...
        exp(-params.epsilonDecay * episode);
    % Select the initial action
    if rand() < epsilon
        action = randi(4); % explore
    else
        [~, action] = max(Q(state(2), state(1), :)); % exploit
    end
    % Single-episode loop
    while ~done && step < params.maxSteps
        % Take the action; observe the next state and reward
        [nextState, reward, done] = stepEnvironment(gridWorld, state, action);
        % Select the next action with the same epsilon-greedy policy
        if rand() < epsilon
            nextAction = randi(4); % explore
        else
            [~, nextAction] = max(Q(nextState(2), nextState(1), :)); % exploit
        end
        % SARSA update (do not bootstrap from terminal transitions)
        currentQ = Q(state(2), state(1), action);
        if done
            target = reward;
        else
            target = reward + params.gamma * Q(nextState(2), nextState(1), nextAction);
        end
        Q(state(2), state(1), action) = currentQ + params.alpha * (target - currentQ);
        % Update the state and action
        state = nextState;
        action = nextAction;
        step = step + 1;
        totalReward = totalReward + reward;
    end
    % Record statistics
    stats.episodeLengths(episode) = step;
    stats.episodeRewards(episode) = totalReward;
    stats.episodeSuccess(episode) = done && isequal(state, gridWorld.goal);
    % Print progress every 100 episodes
    if mod(episode, 100) == 0
        fprintf('Episode %d/%d, Steps: %d, Reward: %.2f, Success: %d\n', ...
            episode, params.numEpisodes, step, totalReward, stats.episodeSuccess(episode));
    end
end

% Extract the greedy policy from the Q table
policy = extractPolicy(Q);
end
4. Visualization functions
To visualize the learning process and the results, we need a few plotting helpers.
function plotGridWorld(gridWorld, policy, path)
% Plot the grid world, the learned policy, and (optionally) a path
figure;
hold on;
axis equal;
% Cells are centered on integer coordinates, so cell boundaries lie on half-integers
axis([0.5, gridWorld.width + 0.5, 0.5, gridWorld.height + 0.5]);

% Draw the grid lines along the cell boundaries
for x = 0.5:1:gridWorld.width + 0.5
    plot([x, x], [0.5, gridWorld.height + 0.5], 'k-');
end
for y = 0.5:1:gridWorld.height + 0.5
    plot([0.5, gridWorld.width + 0.5], [y, y], 'k-');
end

% Draw the start cell
rectangle('Position', [gridWorld.start(1)-0.5, gridWorld.start(2)-0.5, 1, 1], ...
    'FaceColor', 'g', 'EdgeColor', 'none');
text(gridWorld.start(1), gridWorld.start(2), 'Start', ...
    'HorizontalAlignment', 'center', 'VerticalAlignment', 'middle');

% Draw the goal cell
rectangle('Position', [gridWorld.goal(1)-0.5, gridWorld.goal(2)-0.5, 1, 1], ...
    'FaceColor', 'r', 'EdgeColor', 'none');
text(gridWorld.goal(1), gridWorld.goal(2), 'Goal', ...
    'HorizontalAlignment', 'center', 'VerticalAlignment', 'middle');

% Draw the obstacles
for i = 1:size(gridWorld.obstacles, 1)
    obs = gridWorld.obstacles(i, :);
    rectangle('Position', [obs(1)-0.5, obs(2)-0.5, 1, 1], ...
        'FaceColor', 'k', 'EdgeColor', 'none');
end

% Draw the policy arrows
if nargin >= 2
    arrowScale = 0.3;
    for y = 1:gridWorld.height
        for x = 1:gridWorld.width
            % Skip the start, the goal, and obstacle cells
            if isequal([x, y], gridWorld.start) || ...
                    isequal([x, y], gridWorld.goal) || ...
                    isObstacle(gridWorld, [x, y])
                continue;
            end
            action = policy(y, x);
            dx = gridWorld.actions(action, 1) * arrowScale;
            dy = gridWorld.actions(action, 2) * arrowScale;
            quiver(x, y, dx, dy, 'AutoScale', 'off', ...
                'MaxHeadSize', 1, 'LineWidth', 1.5, 'Color', 'b');
        end
    end
end

% Draw the path
if nargin >= 3 && ~isempty(path)
    pathX = path(:, 1);
    pathY = path(:, 2);
    plot(pathX, pathY, 'm-', 'LineWidth', 2);
    plot(pathX, pathY, 'mo', 'MarkerSize', 6, 'MarkerFaceColor', 'm');
end

hold off;
title('Grid world and learned policy');
end
function plotLearningStats(stats, algorithmName)
% Plot training statistics
figure;

% Steps per episode
subplot(3, 1, 1);
plot(stats.episodeLengths);
xlabel('Episode');
ylabel('Steps');
title(sprintf('%s - steps per episode', algorithmName));
grid on;

% Reward per episode
subplot(3, 1, 2);
plot(stats.episodeRewards);
xlabel('Episode');
ylabel('Reward');
title(sprintf('%s - reward per episode', algorithmName));
grid on;

% Success rate (moving average)
subplot(3, 1, 3);
windowSize = 50;
successRate = movmean(stats.episodeSuccess, windowSize);
plot(successRate);
xlabel('Episode');
ylabel('Success rate');
title(sprintf('%s - success rate (%d-episode moving average)', algorithmName, windowSize));
grid on;

% Adjust the figure size
set(gcf, 'Position', [100, 100, 800, 600]);
end
function path = findPath(gridWorld, policy)
% Follow the greedy policy from the start toward the goal and record the path
state = gridWorld.start;
path = state;
visited = zeros(gridWorld.height, gridWorld.width);
visited(state(2), state(1)) = 1;
maxSteps = gridWorld.width * gridWorld.height * 2; % guard against infinite loops
step = 0;
while ~isequal(state, gridWorld.goal) && step < maxSteps
    action = policy(state(2), state(1));
    move = gridWorld.actions(action, :);
    nextState = state + move;
    % Check that the move is valid
    if nextState(1) < 1 || nextState(1) > gridWorld.width || ...
            nextState(2) < 1 || nextState(2) > gridWorld.height || ...
            isObstacle(gridWorld, nextState)
        break; % invalid move; stop following the policy
    end
    % Check whether the state was already visited (loop prevention)
    if visited(nextState(2), nextState(1))
        break; % already visited; stop
    end
    state = nextState;
    path = [path; state];
    visited(state(2), state(1)) = 1;
    step = step + 1;
end
end
function result = isObstacle(gridWorld, state)
% Check whether the given state is an obstacle cell
result = false;
for i = 1:size(gridWorld.obstacles, 1)
    if isequal(state, gridWorld.obstacles(i, :))
        result = true;
        return;
    end
end
end
5. Main function and parameters
Finally, we write a main function that ties all the components together and sets the algorithm parameters.
function main()
% Main function: TD-learning path-planning simulation

% Create the grid-world environment
width = 10;
height = 10;
start = [1, 1];
goal = [10, 10];
obstacles = [3, 3; 3, 4; 3, 5; 4, 5; 5, 5; 6, 5; 7, 5; 8, 5; 8, 4; 8, 3];
gridWorld = createGridWorld(width, height, start, goal, obstacles);

% Set the algorithm parameters
params.alpha = 0.1;          % learning rate
params.gamma = 0.9;          % discount factor
params.epsilonMax = 0.9;     % initial exploration rate
params.epsilonMin = 0.1;     % minimum exploration rate
params.epsilonDecay = 0.001; % exploration-rate decay
params.numEpisodes = 1000;   % number of training episodes
params.maxSteps = 100;       % maximum steps per episode

% Train Q-learning
fprintf('Training Q-learning...\n');
[Q_qlearning, policy_qlearning, stats_qlearning] = qLearning(gridWorld, params);

% Train SARSA
fprintf('Training SARSA...\n');
[Q_sarsa, policy_sarsa, stats_sarsa] = sarsa(gridWorld, params);

% Visualize the results
% 1. Learning curves
plotLearningStats(stats_qlearning, 'Q-learning');
plotLearningStats(stats_sarsa, 'SARSA');

% 2. Grid world, policies, and paths
path_qlearning = findPath(gridWorld, policy_qlearning);
path_sarsa = findPath(gridWorld, policy_sarsa);
plotGridWorld(gridWorld, policy_qlearning, path_qlearning);
title('Q-learning policy and path');
plotGridWorld(gridWorld, policy_sarsa, path_sarsa);
title('SARSA policy and path');

% 3. Compare the two algorithms
fprintf('Algorithm comparison:\n');
fprintf('Q-learning - mean steps: %.2f, mean reward: %.2f, success rate: %.2f%%\n', ...
    mean(stats_qlearning.episodeLengths), ...
    mean(stats_qlearning.episodeRewards), ...
    mean(stats_qlearning.episodeSuccess) * 100);
fprintf('SARSA - mean steps: %.2f, mean reward: %.2f, success rate: %.2f%%\n', ...
    mean(stats_sarsa.episodeLengths), ...
    mean(stats_sarsa.episodeRewards), ...
    mean(stats_sarsa.episodeSuccess) * 100);

% 4. Report the path lengths
fprintf('Q-learning path length: %d\n', size(path_qlearning, 1));
fprintf('SARSA path length: %d\n', size(path_sarsa, 1));
end
6. Running the simulation
To run the simulation, save all of the functions above either in a single file named main.m (main must be the first function in the file) or as separate .m files on the MATLAB path, then call the main function from the command window:
main();
7. Extensions
The simulation can be extended in several directions, for example:
- More algorithms: implement other temporal difference algorithms such as Expected SARSA or Double Q-learning (see the sketch after this list).
- More complex environments: randomly generated mazes, dynamic obstacles, and so on.
- Function approximation: for large state spaces, replace the Q table with a function approximator such as a neural network.
- Multi-agent planning: extend the setup to multi-agent path planning.
- Live visualization: display the agent's exploration and learning progress during training.
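As one example, Expected SARSA replaces the sampled next-action value with the expectation of Q(s', ·) under the ε-greedy behavior policy. A minimal sketch of the modified update, assuming the same variable names as in the SARSA training loop above:
% Expected SARSA update (sketch; as above, drop the bootstrap term when done is true)
nextQvals = squeeze(Q(nextState(2), nextState(1), :));       % 4x1 vector of next-state action values
[~, greedyA] = max(nextQvals);
actionProbs = ones(4, 1) * epsilon / 4;                       % exploration mass spread over all actions
actionProbs(greedyA) = actionProbs(greedyA) + (1 - epsilon);  % extra mass on the greedy action
expectedQ = actionProbs' * nextQvals;                         % expectation instead of a single sample
currentQ = Q(state(2), state(1), action);
Q(state(2), state(1), action) = currentQ + params.alpha * ...
    (reward + params.gamma * expectedQ - currentQ);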