MATLAB Simulation of Temporal-Difference Learning for Path Planning

Temporal-difference (TD) learning is an important family of reinforcement learning methods that combines the strengths of dynamic programming (bootstrapping from current value estimates) and Monte Carlo methods (learning from sampled experience). In path planning, TD learning lets an agent learn a good path purely by interacting with the environment, without needing a model of its dynamics.
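
For reference, the core of tabular TD(0) prediction is a single incremental update per observed transition. The snippet below is a minimal sketch; the names V, alpha, and gamma are illustrative and not part of the simulation code that follows.

% Tabular TD(0) update for one observed transition (s, reward, sNext).
% V is a table of state-value estimates, alpha the learning rate, gamma the discount factor.
tdError = reward + gamma * V(sNext) - V(s);   % TD error: bootstrapped target minus current estimate
V(s) = V(s) + alpha * tdError;                % move V(s) a small step toward the target

Q-learning and SARSA, implemented below, apply the same idea to state-action values Q(s, a) instead of state values V(s).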

1. Environment Setup

First, we create a grid-world environment that contains a start cell, a goal cell, obstacles, and the reward definition.

function gridWorld = createGridWorld(width, height, start, goal, obstacles)
    % Create the grid-world environment
    % Inputs:
    %   width     - grid width (number of columns)
    %   height    - grid height (number of rows)
    %   start     - start coordinates [x, y]
    %   goal      - goal coordinates [x, y]
    %   obstacles - obstacle coordinates, one per row [x1, y1; x2, y2; ...]
    
    gridWorld.width = width;
    gridWorld.height = height;
    gridWorld.start = start;
    gridWorld.goal = goal;
    gridWorld.obstacles = obstacles;
    
    % Define actions: up (1), right (2), down (3), left (4); here 'up' decreases the y index
    gridWorld.actions = [0, -1; 1, 0; 0, 1; -1, 0]; % [dx, dy]
    gridWorld.actionNames = {'Up', 'Right', 'Down', 'Left'};
    
    % Define rewards
    gridWorld.rewardGoal = 10;      % reward for reaching the goal
    gridWorld.rewardObstacle = -10; % penalty for hitting an obstacle
    gridWorld.rewardStep = -1;      % cost per step
end
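
As a quick usage sketch (the 5x5 size and obstacle positions here are arbitrary example values, not the ones used in the main function later):

% Example: a 5x5 world with a short wall of three obstacle cells
gw = createGridWorld(5, 5, [1, 1], [5, 5], [3, 2; 3, 3; 3, 4]);
disp(gw.actionNames); % lists the four available actions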

2. Q-learning Implementation

Q-learning is a widely used temporal-difference algorithm. It learns the state-action value function Q(s, a) and derives a policy from it; because the update bootstraps from the greedy action in the next state regardless of the action actually taken, it is an off-policy method.
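
The update applied after each transition (s, a, reward, s') is the standard Q-learning rule, shown here with generic Q(s, a) indexing for readability (the implementation below stores the table as Q(y, x, action)):

% Q-learning update for one transition (s, a, reward, sNext):
Q(s, a) = Q(s, a) + alpha * (reward + gamma * max(Q(sNext, :)) - Q(s, a));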

function [Q, policy, stats] = qLearning(gridWorld, params)
    % Q-learning implementation
    % Inputs:
    %   gridWorld - grid-world environment
    %   params    - algorithm parameters
    % Outputs:
    %   Q      - state-action value function
    %   policy - greedy policy extracted from Q
    %   stats  - training statistics
    
    % Initialize the Q-table
    Q = zeros(gridWorld.height, gridWorld.width, 4);
    
    % Initialize statistics
    stats.episodeLengths = zeros(params.numEpisodes, 1);
    stats.episodeRewards = zeros(params.numEpisodes, 1);
    stats.episodeSuccess = zeros(params.numEpisodes, 1);
    
    % Training loop
    for episode = 1:params.numEpisodes
        % Initialize the episode
        state = gridWorld.start;
        done = false;
        step = 0;
        totalReward = 0;
        
        % Exponentially decaying exploration rate for the epsilon-greedy policy
        epsilon = params.epsilonMin + (params.epsilonMax - params.epsilonMin) * ...
                 exp(-params.epsilonDecay * episode);
        
        % Loop within a single episode
        while ~done && step < params.maxSteps
            % Select an action (epsilon-greedy policy)
            if rand() < epsilon
                action = randi(4); % explore: random action
            else
                [~, action] = max(Q(state(2), state(1), :)); % exploit: greedy action
            end
            
            % Take the action; observe the next state, reward, and termination flag
            [nextState, reward, done] = stepEnvironment(gridWorld, state, action);
            
            % Q-learning update (off-policy: bootstrap from the best next action)
            currentQ = Q(state(2), state(1), action);
            nextMaxQ = max(Q(nextState(2), nextState(1), :));
            
            % TD target; do not bootstrap past a terminal transition
            target = reward + ~done * params.gamma * nextMaxQ;
            Q(state(2), state(1), action) = currentQ + params.alpha * (target - currentQ);
            
            % Update the state and statistics
            state = nextState;
            step = step + 1;
            totalReward = totalReward + reward;
        end
        
        % Record statistics
        stats.episodeLengths(episode) = step;
        stats.episodeRewards(episode) = totalReward;
        stats.episodeSuccess(episode) = done && isequal(state, gridWorld.goal);
        
        % Print progress every 100 episodes
        if mod(episode, 100) == 0
            fprintf('Episode %d/%d, Steps: %d, Reward: %.2f, Success: %d\n', ...
                episode, params.numEpisodes, step, totalReward, stats.episodeSuccess(episode));
        end
    end
    
    % Extract the greedy policy from the Q-table
    policy = extractPolicy(Q);
end

function [nextState, reward, done] = stepEnvironment(gridWorld, state, action)
    % Environment step (transition) function
    % Inputs:
    %   gridWorld - grid-world environment
    %   state     - current state [x, y]
    %   action    - action index (1-4)
    % Outputs:
    %   nextState - next state
    %   reward    - reward received
    %   done      - whether the episode terminates
    
    % Compute the candidate next state
    move = gridWorld.actions(action, :);
    nextState = state + move;
    
    % If the move would leave the grid, stay in place
    if nextState(1) < 1 || nextState(1) > gridWorld.width || ...
       nextState(2) < 1 || nextState(2) > gridWorld.height
        nextState = state; % out of bounds: stay in place
    end
    end
    
    % Check whether the goal has been reached
    if isequal(nextState, gridWorld.goal)
        reward = gridWorld.rewardGoal;
        done = true;
        return;
    end
    
    % Check whether an obstacle was hit
    isObstacle = false;
    for i = 1:size(gridWorld.obstacles, 1)
        if isequal(nextState, gridWorld.obstacles(i, :))
            isObstacle = true;
            break;
        end
    end
    
    if isObstacle
        reward = gridWorld.rewardObstacle;
        done = true;
        nextState = state; % hit an obstacle: stay in place
        return;
    end
    
    % Ordinary move
    reward = gridWorld.rewardStep;
    done = false;
end

function policy = extractPolicy(Q)
    % Extract the greedy policy from the Q-table (max breaks ties toward the lowest action index)
    [height, width, ~] = size(Q);
    policy = zeros(height, width);
    
    for y = 1:height
        for x = 1:width
            [~, policy(y, x)] = max(Q(y, x, :));
        end
    end
end

3. SARSA Implementation

SARSA is another temporal-difference algorithm. Its main difference from Q-learning is that it is on-policy: the update bootstraps from the Q-value of the action actually selected in the next state by the same epsilon-greedy behavior policy, rather than from the greedy action.
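
For comparison with the Q-learning rule above (again with generic Q(s, a) indexing; the code below stores the table as Q(y, x, action)):

% SARSA update for one transition (s, a, reward, sNext, aNext),
% where aNext is the next action actually chosen by the epsilon-greedy policy:
Q(s, a) = Q(s, a) + alpha * (reward + gamma * Q(sNext, aNext) - Q(s, a));

The only change is that max(Q(sNext, :)) is replaced by Q(sNext, aNext), which makes SARSA's value estimates sensitive to the exploration behavior itself.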

function [Q, policy, stats] = sarsa(gridWorld, params)
    % SARSA implementation
    % Inputs:
    %   gridWorld - grid-world environment
    %   params    - algorithm parameters
    % Outputs:
    %   Q      - state-action value function
    %   policy - greedy policy extracted from Q
    %   stats  - training statistics
    
    % Initialize the Q-table
    Q = zeros(gridWorld.height, gridWorld.width, 4);
    
    % Initialize statistics
    stats.episodeLengths = zeros(params.numEpisodes, 1);
    stats.episodeRewards = zeros(params.numEpisodes, 1);
    stats.episodeSuccess = zeros(params.numEpisodes, 1);
    
    % Training loop
    for episode = 1:params.numEpisodes
        % Initialize the episode
        state = gridWorld.start;
        done = false;
        step = 0;
        totalReward = 0;
        
        % Exponentially decaying exploration rate for the epsilon-greedy policy
        epsilon = params.epsilonMin + (params.epsilonMax - params.epsilonMin) * ...
                 exp(-params.epsilonDecay * episode);
        
        % Select the initial action (epsilon-greedy policy)
        if rand() < epsilon
            action = randi(4); % explore: random action
        else
            [~, action] = max(Q(state(2), state(1), :)); % exploit: greedy action
        end
        
        % Loop within a single episode
        while ~done && step < params.maxSteps
            % Take the action; observe the next state, reward, and termination flag
            [nextState, reward, done] = stepEnvironment(gridWorld, state, action);
            
            % Select the next action with the same epsilon-greedy policy
            if rand() < epsilon
                nextAction = randi(4); % explore: random action
            else
                [~, nextAction] = max(Q(nextState(2), nextState(1), :)); % exploit: greedy action
            end
            
            % SARSA update (on-policy: bootstrap from the action actually selected)
            currentQ = Q(state(2), state(1), action);
            nextQ = Q(nextState(2), nextState(1), nextAction);
            
            % TD target; do not bootstrap past a terminal transition
            target = reward + ~done * params.gamma * nextQ;
            Q(state(2), state(1), action) = currentQ + params.alpha * (target - currentQ);
            
            % Update the state, action, and statistics
            state = nextState;
            action = nextAction;
            step = step + 1;
            totalReward = totalReward + reward;
        end
        
        % Record statistics
        stats.episodeLengths(episode) = step;
        stats.episodeRewards(episode) = totalReward;
        stats.episodeSuccess(episode) = done && isequal(state, gridWorld.goal);
        
        % Print progress every 100 episodes
        if mod(episode, 100) == 0
            fprintf('Episode %d/%d, Steps: %d, Reward: %.2f, Success: %d\n', ...
                episode, params.numEpisodes, step, totalReward, stats.episodeSuccess(episode));
        end
    end
    
    % Extract the greedy policy from the Q-table
    policy = extractPolicy(Q);
end

4. Visualization Functions

To make the learning process and its results easy to inspect, we add a few visualization functions.

function plotGridWorld(gridWorld, policy, path)
    % Plot the grid world, the learned policy, and (optionally) a path
    figure;
    hold on;
    axis equal;
    axis([0.5, gridWorld.width + 0.5, 0.5, gridWorld.height + 0.5]);
    
    % Draw grid lines (cells are centered on integer coordinates)
    for x = 0.5:1:gridWorld.width + 0.5
        plot([x, x], [0.5, gridWorld.height + 0.5], 'k-');
    end
    for y = 0.5:1:gridWorld.height + 0.5
        plot([0.5, gridWorld.width + 0.5], [y, y], 'k-');
    end
    
    % Draw the start cell
    rectangle('Position', [gridWorld.start(1)-0.5, gridWorld.start(2)-0.5, 1, 1], ...
              'FaceColor', 'g', 'EdgeColor', 'none');
    text(gridWorld.start(1), gridWorld.start(2), 'Start', ...
         'HorizontalAlignment', 'center', 'VerticalAlignment', 'middle');
    
    % Draw the goal cell
    rectangle('Position', [gridWorld.goal(1)-0.5, gridWorld.goal(2)-0.5, 1, 1], ...
              'FaceColor', 'r', 'EdgeColor', 'none');
    text(gridWorld.goal(1), gridWorld.goal(2), 'Goal', ...
         'HorizontalAlignment', 'center', 'VerticalAlignment', 'middle');
    
    % Draw the obstacles
    for i = 1:size(gridWorld.obstacles, 1)
        obs = gridWorld.obstacles(i, :);
        rectangle('Position', [obs(1)-0.5, obs(2)-0.5, 1, 1], ...
                  'FaceColor', 'k', 'EdgeColor', 'none');
    end
    
    % Draw policy arrows
    if nargin >= 2
        arrowScale = 0.3;
        for y = 1:gridWorld.height
            for x = 1:gridWorld.width
                % Skip the start, goal, and obstacle cells
                if isequal([x, y], gridWorld.start) || ...
                   isequal([x, y], gridWorld.goal) || ...
                   isObstacle(gridWorld, [x, y])
                    continue;
                end
                
                action = policy(y, x);
                dx = gridWorld.actions(action, 1) * arrowScale;
                dy = gridWorld.actions(action, 2) * arrowScale;
                
                quiver(x, y, dx, dy, 'AutoScale', 'off', ...
                       'MaxHeadSize', 1, 'LineWidth', 1.5, 'Color', 'b');
            end
        end
    end
    
    % Draw the path
    if nargin >= 3 && ~isempty(path)
        pathX = path(:, 1);
        pathY = path(:, 2);
        plot(pathX, pathY, 'm-', 'LineWidth', 2);
        plot(pathX, pathY, 'mo', 'MarkerSize', 6, 'MarkerFaceColor', 'm');
    end
    
    hold off;
    title('Grid World and Learned Policy');
end

function plotLearningStats(stats, algorithmName)
    % Plot training statistics
    figure;
    
    % Steps per episode
    subplot(3, 1, 1);
    plot(stats.episodeLengths);
    xlabel('Episode');
    ylabel('Steps');
    title(sprintf('%s - Steps per Episode', algorithmName));
    grid on;
    
    % Total reward per episode
    subplot(3, 1, 2);
    plot(stats.episodeRewards);
    xlabel('Episode');
    ylabel('Reward');
    title(sprintf('%s - Reward per Episode', algorithmName));
    grid on;
    
    % Success rate (moving average)
    subplot(3, 1, 3);
    windowSize = 50;
    successRate = movmean(stats.episodeSuccess, windowSize);
    plot(successRate);
    xlabel('Episode');
    ylabel('Success rate');
    title(sprintf('%s - Success Rate (%d-episode moving average)', algorithmName, windowSize));
    grid on;
    
    % Adjust the figure size
    set(gcf, 'Position', [100, 100, 800, 600]);
end

function path = findPath(gridWorld, policy)
    % Follow the greedy policy from the start toward the goal to recover a path
    state = gridWorld.start;
    path = state;
    visited = zeros(gridWorld.height, gridWorld.width);
    visited(state(2), state(1)) = 1;
    
    maxSteps = gridWorld.width * gridWorld.height * 2; % guard against infinite loops
    step = 0;
    
    while ~isequal(state, gridWorld.goal) && step < maxSteps
        action = policy(state(2), state(1));
        move = gridWorld.actions(action, :);
        nextState = state + move;
        
        % Check whether the move is valid
        if nextState(1) < 1 || nextState(1) > gridWorld.width || ...
           nextState(2) < 1 || nextState(2) > gridWorld.height || ...
           isObstacle(gridWorld, nextState)
            break; % invalid move, stop following the policy
        end
        
        % Check whether this state was already visited (cycle detection)
        if visited(nextState(2), nextState(1))
            break; % already visited, stop to avoid looping
        end
        
        state = nextState;
        path = [path; state];
        visited(state(2), state(1)) = 1;
        step = step + 1;
    end
end

function result = isObstacle(gridWorld, state)
    % Check whether the given state is an obstacle cell
    result = false;
    for i = 1:size(gridWorld.obstacles, 1)
        if isequal(state, gridWorld.obstacles(i, :))
            result = true;
            return;
        end
    end
end

5. Main Function and Parameter Settings

Finally, we write a main function that ties all of the components together and sets the algorithm parameters.

function main()
    % Main function: TD-learning path-planning simulation
    
    % Create the grid-world environment
    width = 10;
    height = 10;
    start = [1, 1];
    goal = [10, 10];
    obstacles = [3, 3; 3, 4; 3, 5; 4, 5; 5, 5; 6, 5; 7, 5; 8, 5; 8, 4; 8, 3];
    
    gridWorld = createGridWorld(width, height, start, goal, obstacles);
    
    % Algorithm parameters
    params.alpha = 0.1;          % learning rate
    params.gamma = 0.9;          % discount factor
    params.epsilonMax = 0.9;     % initial exploration rate
    params.epsilonMin = 0.1;     % minimum exploration rate
    params.epsilonDecay = 0.001; % exploration decay rate
    params.numEpisodes = 1000;   % number of training episodes
    params.maxSteps = 100;       % maximum steps per episode
    
    % Train Q-learning
    fprintf('Training Q-learning...\n');
    [Q_qlearning, policy_qlearning, stats_qlearning] = qLearning(gridWorld, params);
    
    % Train SARSA
    fprintf('Training SARSA...\n');
    [Q_sarsa, policy_sarsa, stats_sarsa] = sarsa(gridWorld, params);
    
    % Visualize the results
    % 1. Learning curves
    plotLearningStats(stats_qlearning, 'Q-learning');
    plotLearningStats(stats_sarsa, 'SARSA');
    
    % 2. Grid world, policies, and greedy paths
    path_qlearning = findPath(gridWorld, policy_qlearning);
    path_sarsa = findPath(gridWorld, policy_sarsa);
    
    plotGridWorld(gridWorld, policy_qlearning, path_qlearning);
    title('Q-learning Policy and Path');
    
    plotGridWorld(gridWorld, policy_sarsa, path_sarsa);
    title('SARSA Policy and Path');
    
    % 3. Compare the performance of the two algorithms
    fprintf('Algorithm comparison:\n');
    fprintf('Q-learning - mean steps: %.2f, mean reward: %.2f, success rate: %.2f%%\n', ...
        mean(stats_qlearning.episodeLengths), ...
        mean(stats_qlearning.episodeRewards), ...
        mean(stats_qlearning.episodeSuccess) * 100);
    
    fprintf('SARSA - mean steps: %.2f, mean reward: %.2f, success rate: %.2f%%\n', ...
        mean(stats_sarsa.episodeLengths), ...
        mean(stats_sarsa.episodeRewards), ...
        mean(stats_sarsa.episodeSuccess) * 100);
    
    % 4. Report the length of each greedy path
    fprintf('Q-learning path length: %d\n', size(path_qlearning, 1));
    fprintf('SARSA path length: %d\n', size(path_sarsa, 1));
end

6. Running the Simulation

To run the simulation, save the functions above either in separate .m files on the MATLAB path or in a single file named main.m (with main as the first function and the others as local functions below it), and then call the main function from the command window:

main();


7. Extensions

The simulation can be extended further, for example:

  1. More algorithms: implement other temporal-difference methods such as Expected SARSA or Double Q-learning (see the sketch after this list).

  2. More complex environments: randomly generated mazes, moving obstacles, and so on.

  3. Function approximation: for large state spaces, replace the Q-table with a function approximator such as a neural network.

  4. Multi-agent: extend the setup to a multi-agent path-planning problem.

  5. Live visualization: display the agent's exploration and learning in real time during training.
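
As an illustration of the first extension, Expected SARSA bootstraps from the expectation of Q over next actions under the epsilon-greedy policy instead of from a single sampled action. The helper below is a minimal sketch written against the same Q-table layout used above; it is not part of the code in this article, and the name expectedSarsaUpdate is hypothetical:

function Q = expectedSarsaUpdate(Q, state, action, reward, nextState, done, epsilon, params)
    % Expected SARSA: bootstrap from the expected Q-value of the next state
    % under the epsilon-greedy behavior policy.
    nextQ = squeeze(Q(nextState(2), nextState(1), :));      % 4x1 vector of next-state action values
    [~, greedy] = max(nextQ);
    probs = ones(4, 1) * epsilon / 4;                        % exploration probability for each action
    probs(greedy) = probs(greedy) + (1 - epsilon);           % extra probability mass on the greedy action
    expectedNextQ = probs' * nextQ;                          % expectation over next actions
    
    target = reward + ~done * params.gamma * expectedNextQ;  % no bootstrapping past a terminal transition
    currentQ = Q(state(2), state(1), action);
    Q(state(2), state(1), action) = currentQ + params.alpha * (target - currentQ);
end

This function could replace the Q-value update inside the episode loop of sarsa, with the epsilon-greedy action selection left unchanged.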
