Reinforcement Learning: DQN Game Training

2.1 State Input and Feature Sequence Construction

The input to a DQN is usually a stream of consecutive game frames. To capture temporal dynamics (such as an object's velocity and direction), the original paper stacks the 4 most recent frames into a single state tensor.
"""
Deep Q-Network (DQN) agent implementation.
Contains the core components: network architecture, experience replay, and training logic.
"""
import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class DQN(nn.Module):
    """DQN network - convolutional version for Atari games."""

    def __init__(self, input_shape, n_actions):
        """
        Args:
            input_shape: shape of the input state (channels, height, width)
            n_actions: size of the action space
        """
        super(DQN, self).__init__()

        # Convolutional feature extraction layers
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        # Number of features produced by the convolutional layers
        conv_out_size = self._get_conv_out(input_shape)

        # Fully connected value-estimation layers
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def _get_conv_out(self, shape):
        """Compute the number of features output by the convolutional layers."""
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        """Forward pass."""
        conv_out = self.conv(x)
        conv_out_flat = conv_out.view(conv_out.size(0), -1)
        return self.fc(conv_out_flat)


class SimpleDQN(nn.Module):
    """Simplified DQN network - for environments with vector states such as CartPole."""

    def __init__(self, state_dim, n_actions):
        super(SimpleDQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )

    def forward(self, x):
        return self.net(x)


class ReplayBuffer:
    """Experience replay buffer."""

    def __init__(self, capacity=100000):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store a single transition."""
        experience = (state, action, reward, next_state, done)
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Randomly sample a batch of transitions."""
        if len(self.buffer) < batch_size:
            return None

        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        return (
            torch.stack(states),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.stack(next_states),
            torch.FloatTensor(dones)
        )

    def __len__(self):
        return len(self.buffer)


class DQNAgent:
    """DQN agent."""

    def __init__(self, state_dim, n_actions, use_image_input=False,
                 lr=1e-4, gamma=0.99, tau=0.005, device='cpu'):
        """
        Args:
            state_dim: state dimensionality (or shape, for image input)
            n_actions: number of actions
            use_image_input: whether the state is an image
            lr: learning rate
            gamma: discount factor
            tau: soft-update coefficient for the target network
            device: computation device
        """
        self.device = device
        self.n_actions = n_actions
        self.gamma = gamma
        self.tau = tau
        self.use_image_input = use_image_input

        # Build the networks
        if use_image_input:
            self.online_net = DQN(state_dim, n_actions).to(device)
            self.target_net = DQN(state_dim, n_actions).to(device)
        else:
            self.online_net = SimpleDQN(state_dim, n_actions).to(device)
            self.target_net = SimpleDQN(state_dim, n_actions).to(device)

        # Synchronise the target network with the online network
        self.target_net.load_state_dict(self.online_net.state_dict())
        self.target_net.eval()  # target network is kept in evaluation mode

        # Optimizer
        self.optimizer = optim.Adam(self.online_net.parameters(), lr=lr)

        # Experience replay buffer
        self.memory = ReplayBuffer(capacity=10000)

        # Training step counter
        self.train_step = 0

    def select_action(self, state, epsilon=0.1):
        """Select an action with an epsilon-greedy policy."""
        if random.random() < epsilon:
            # Explore: pick a random action
            return random.randint(0, self.n_actions - 1)
        else:
            # Exploit: pick the action with the highest Q-value
            with torch.no_grad():
                state = state.unsqueeze(0) if state.dim() == 3 else state
                state = state.to(self.device)
                q_values = self.online_net(state)
                return q_values.argmax().item()

    def store_experience(self, state, action, reward, next_state, done):
        """Store a transition in the replay buffer."""
        self.memory.push(state, action, reward, next_state, done)

    def learn(self, batch_size=64):
        """Perform one learning step from replayed experience."""
        # Sample a batch
        batch = self.memory.sample(batch_size)
        if batch is None:
            return 0

        states, actions, rewards, next_states, dones = batch

        # Move tensors to the computation device
        states = states.to(self.device)
        actions = actions.to(self.device)
        rewards = rewards.to(self.device)
        next_states = next_states.to(self.device)
        dones = dones.to(self.device)

        # Current Q-values for the actions actually taken
        current_q_values = self.online_net(states).gather(1, actions.unsqueeze(1))

        # Target Q-values from the target network
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
            target_q_values = rewards + (self.gamma * next_q_values * (1 - dones))

        # Loss
        loss = F.smooth_l1_loss(current_q_values, target_q_values.unsqueeze(1))

        # Optimization step
        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), 10.0)

        self.optimizer.step()

        # Soft-update the target network
        self.soft_update_target_network()

        self.train_step += 1

        return loss.item()

    def soft_update_target_network(self):
        """Soft-update the target network parameters."""
        for target_param, online_param in zip(self.target_net.parameters(),
                                              self.online_net.parameters()):
            target_param.data.copy_(
                self.tau * online_param.data + (1.0 - self.tau) * target_param.data
            )

    def hard_update_target_network(self):
        """Hard-update (fully copy) the target network parameters."""
        self.target_net.load_state_dict(self.online_net.state_dict())

    def save(self, path):
        """Save the model."""
        torch.save({
            'online_net_state_dict': self.online_net.state_dict(),
            'target_net_state_dict': self.target_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_step': self.train_step,
        }, path)

    def load(self, path):
        """Load the model."""
        checkpoint = torch.load(path, map_location=self.device)
        self.online_net.load_state_dict(checkpoint['online_net_state_dict'])
        self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.train_step = checkpoint['train_step']

class StateProcessor:
    """State processor - handles image observations."""

    def __init__(self, frame_stack=4, img_size=(84, 84)):
        self.frame_stack = frame_stack
        self.img_size = img_size
        self.frame_buffer = deque(maxlen=frame_stack)

    def reset(self):
        """Clear the frame buffer."""
        self.frame_buffer.clear()

    def process_state(self, observation):
        """
        Convert a raw observation into the network input format.
        Args:
            observation: raw observation
        Returns:
            state_tensor: processed state tensor
        """
        # Convert to grayscale
        if len(observation.shape) == 3:
            # RGB image to grayscale
            gray = np.dot(observation[..., :3], [0.2989, 0.5870, 0.1140])
        else:
            gray = observation

        # Resize
        from skimage.transform import resize
        resized = resize(gray, self.img_size, mode='reflect', anti_aliasing=True)

        # Normalise to [0, 1]
        resized = resized.astype(np.float32) / 255.0

        # Append to the frame buffer
        self.frame_buffer.append(resized)

        # If the buffer is not full yet, pad it with the first frame
        while len(self.frame_buffer) < self.frame_stack:
            self.frame_buffer.appendleft(resized)

        # Stack the frames
        state = np.stack(self.frame_buffer, axis=0)

        # Convert to a tensor
        state_tensor = torch.FloatTensor(state).unsqueeze(0)

        return state_tensor


class VectorStateProcessor:
    """Vector state processor - handles vector observations (e.g. CartPole)."""

    def __init__(self):
        pass

    def reset(self):
        """Reset the processor."""
        pass

    def process_state(self, observation):
        """
        Process a vector observation.
        Args:
            observation: raw observation vector
        Returns:
            state_tensor: processed state tensor
        """
        # Convert directly to a tensor
        state_tensor = torch.FloatTensor(observation).unsqueeze(0)
        return state_tensor
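
A minimal smoke-test sketch (not part of the original post) that wires StateProcessor and DQNAgent together on randomly generated frames; in practice the observations would come from an Atari environment, and n_actions=6 is an arbitrary placeholder.

# Illustrative only: exercise the image pipeline on dummy data.
processor = StateProcessor(frame_stack=4, img_size=(84, 84))
agent = DQNAgent(state_dim=(4, 84, 84), n_actions=6, use_image_input=True)

processor.reset()
fake_frame = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)  # Atari-sized RGB frame
state = processor.process_state(fake_frame)            # shape: (1, 4, 84, 84)
action = agent.select_action(state.squeeze(0), epsilon=0.1)

# Store one transition and run a single learning step on it.
next_state = processor.process_state(fake_frame)
agent.store_experience(state.squeeze(0), action, 1.0, next_state.squeeze(0), False)
loss = agent.learn(batch_size=1)
print(f"action = {action}, loss = {loss:.4f}")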

2.2 Feature Extraction Backbone

This is the core of the DQN: a stack of convolutional layers that extracts high-level features from raw pixels.
class DQNBackbone(nn.Module):
    """DQN feature extraction backbone (Atari version)."""

    def __init__(self, input_channels=4):
        super(DQNBackbone, self).__init__()
        self.conv_layers = nn.Sequential(
            # Conv layer 1: 4 input channels, 32 output channels
            nn.Conv2d(in_channels=input_channels, out_channels=32, kernel_size=8, stride=4),
            # kernel_size=8, stride=4 -> output size: (84-8)/4 + 1 = 20
            nn.ReLU(inplace=True),
            # Conv layer 2: 32 input channels, 64 output channels
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            # kernel_size=4, stride=2 -> output size: (20-4)/2 + 1 = 9
            nn.ReLU(inplace=True),
            # Conv layer 3: 64 input channels, 64 output channels
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            # kernel_size=3, stride=1 -> output size: (9-3)/1 + 1 = 7
            nn.ReLU(inplace=True)
        )
        # Number of features produced by the convolutional layers
        self._feature_dim = self._calculate_conv_output(input_channels)

    def _calculate_conv_output(self, input_channels):
        """Run a zero tensor through the conv layers to compute the output size."""
        with torch.no_grad():
            dummy_input = torch.zeros(1, input_channels, 84, 84)
            dummy_output = self.conv_layers(dummy_input)
            return int(np.prod(dummy_output.size()))

    def forward(self, x):
        """
        Args:
            x: input tensor of shape (batch_size, 4, 84, 84)
        Returns:
            features: flattened feature vector of shape (batch_size, 3136)
        """
        conv_out = self.conv_layers(x)  # shape: (batch_size, 64, 7, 7)
        print(f"Conv output shape: {conv_out.shape}")
        features = conv_out.view(conv_out.size(0), -1)  # flatten -> (batch_size, 64*7*7 = 3136)
        print(f"Flattened feature shape: {features.shape}")
        return features
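
A quick, hypothetical shape check for the backbone (names and sizes follow the class above):

# Illustrative only: verify the backbone's output dimensionality.
backbone = DQNBackbone(input_channels=4)
dummy = torch.zeros(2, 4, 84, 84)               # a batch of 2 stacked-frame states
features = backbone(dummy)                      # prints the intermediate shapes
assert features.shape == (2, 64 * 7 * 7)        # 3136-dimensional feature vectors
print(backbone._feature_dim)                    # 3136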

2.3 Value Estimation Head

The head maps the extracted features to a Q-value for each possible action.
class DQNValueHead(nn.Module):
    """
    DQN value estimation head.
    Input:  feature vectors of shape (batch_size, 3136)
    Output: Q-values of shape (batch_size, n_actions), one per action
    """

    def __init__(self, feature_dim, n_actions):
        super(DQNValueHead, self).__init__()
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features=feature_dim, out_features=512),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=n_actions)  # raw Q-value for each action
        )

    def forward(self, features):
        """
        Args:
            features: features from the backbone, shape (batch_size, feature_dim)
        Returns:
            q_values: Q-values per action, shape (batch_size, n_actions)
        """
        q_values = self.fc_layers(features)
        print(f"Q-value output shape: {q_values.shape}")
        return q_values
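
A short, hypothetical sketch that chains the backbone and head into a full Q-network (n_actions=4 is a placeholder):

# Illustrative only: backbone + head = complete Q-network.
backbone = DQNBackbone(input_channels=4)
head = DQNValueHead(feature_dim=backbone._feature_dim, n_actions=4)

x = torch.zeros(1, 4, 84, 84)       # one preprocessed state
q_values = head(backbone(x))        # shape: (1, 4), one Q-value per action
print(q_values)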

2.4 Experience Replay & Batch Sampling

This is the key mechanism that stabilises DQN training. Analogous to "window partitioning" in Swin Transformer, it carves small mini-batches out of a large pool of interaction experience for learning.
class ReplayBuffer:
    """
    Experience replay buffer.
    Stores the agent's transitions (s, a, r, s', done) and samples random mini-batches
    to break the correlation between consecutive experiences.
    """

    def __init__(self, capacity=100000):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        """Store a single transition."""
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """
        Randomly sample a mini-batch of transitions.
        Args:
            batch_size: batch size
        Returns:
            Batched (states, actions, rewards, next_states, dones)
        """
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, next_states, dones = zip(*[self.buffer[idx] for idx in indices])
        # Convert to PyTorch tensors
        return (torch.stack(states),
                torch.LongTensor(actions),
                torch.FloatTensor(rewards),
                torch.stack(next_states),
                torch.FloatTensor(dones))

    def __len__(self):
        return len(self.buffer)
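
A small, hypothetical round-trip through the buffer with dummy tensors, to make the sampled shapes concrete:

# Illustrative only: push a few dummy transitions and sample a mini-batch.
buffer = ReplayBuffer(capacity=1000)
for _ in range(8):
    s = torch.zeros(4, 84, 84)
    s_next = torch.zeros(4, 84, 84)
    buffer.push(s, 0, 1.0, s_next, False)

states, actions, rewards, next_states, dones = buffer.sample(batch_size=4)
print(states.shape, actions.shape, dones.shape)   # torch.Size([4, 4, 84, 84]) torch.Size([4]) torch.Size([4])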

2.5 Target Network Mechanism

Analogous to the cross-window information exchange introduced by the shifted windows in Swin Transformer, the target network provides relatively stable target Q-values and thereby mitigates the training instability caused by bootstrapping.
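
Concretely, for a sampled transition (s, a, r, s', done) the code below builds the standard one-step TD target with the frozen target network,

    y = r + γ · (1 − done) · max_a' Q_target(s', a'),

and trains the online network to pull Q_online(s, a) toward y with a Huber (Smooth L1) loss.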

class DQNAgent:
    """DQN agent that combines the online network, the target network, and the learning logic."""

    def __init__(self, state_channels, n_actions, learning_rate=1e-4, gamma=0.99, tau=0.005):
        self.gamma = gamma  # discount factor for future rewards
        self.tau = tau      # soft-update coefficient for the target network

        # Online network: used to select actions and compute current Q-values
        self.online_backbone = DQNBackbone(input_channels=state_channels)
        self.online_head = DQNValueHead(feature_dim=self.online_backbone._feature_dim, n_actions=n_actions)
        self.online_network = nn.Sequential(self.online_backbone, self.online_head)
        self.optimizer = torch.optim.Adam(self.online_network.parameters(), lr=learning_rate)

        # Target network: used to compute target Q-values; its parameters are periodically synced from the online network
        self.target_backbone = DQNBackbone(input_channels=state_channels)
        self.target_head = DQNValueHead(feature_dim=self.target_backbone._feature_dim, n_actions=n_actions)
        self.target_network = nn.Sequential(self.target_backbone, self.target_head)
        # At initialisation the target network is an exact copy of the online network
        self._hard_update_target_network()

    def _hard_update_target_network(self):
        """Hard update: copy the online network's parameters into the target network."""
        self.target_network.load_state_dict(self.online_network.state_dict())
        print("Target network parameters hard-updated.")

    def _soft_update_target_network(self):
        """Soft update: slowly blend the online parameters into the target network (Polyak averaging)."""
        for target_param, online_param in zip(self.target_network.parameters(), self.online_network.parameters()):
            target_param.data.copy_(self.tau * online_param.data + (1.0 - self.tau) * target_param.data)

    def compute_loss(self, batch):
        """
        Compute the loss for one batch (Huber / Smooth L1 loss).
        Args:
            batch: a tuple (s, a, r, s', done) sampled from the ReplayBuffer
        Returns:
            loss: scalar loss
        """
        states, actions, rewards, next_states, dones = batch

        # 1. Current Q-values of the taken actions, from the online network
        current_q_values = self.online_network(states).gather(1, actions.unsqueeze(1))  # shape: (batch, 1)
        print(f"Current Q-value shape: {current_q_values.shape}")

        # 2. Maximum next-state Q-values from the target network (Double DQN would use the online network to pick the action)
        with torch.no_grad():  # target computation must not produce gradients
            next_q_values = self.target_network(next_states).max(1)[0]  # shape: (batch,)
            target_q_values = rewards + (self.gamma * next_q_values * (1 - dones))
            print(f"Target Q-value shape: {target_q_values.unsqueeze(1).shape}")

        # 3. Huber loss (more robust to outliers than MSE)
        loss = nn.functional.smooth_l1_loss(current_q_values, target_q_values.unsqueeze(1))
        print(f"Computed loss: {loss.item():.4f}")
        return loss

    def select_action(self, state, epsilon):
        """
        Select an action with an epsilon-greedy policy.
        Args:
            state: current state
            epsilon: exploration probability
        Returns:
            action: index of the selected action
        """
        if np.random.rand() < epsilon:
            # Explore: pick a random action
            return np.random.randint(self.online_head.fc_layers[-1].out_features)
        else:
            # Exploit: pick the action with the highest estimated Q-value
            with torch.no_grad():
                q_values = self.online_network(state)
                print(f"Q-value shape during action selection: {q_values.shape}")
                return q_values.argmax().item()
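
A minimal sketch (not from the original post) that runs one update step on a fabricated batch, just to exercise compute_loss and the target-network updates:

# Illustrative only: one optimisation step on a fake batch of 2 transitions.
agent = DQNAgent(state_channels=4, n_actions=4)

states = torch.zeros(2, 4, 84, 84)
actions = torch.LongTensor([0, 1])
rewards = torch.FloatTensor([1.0, 0.0])
next_states = torch.zeros(2, 4, 84, 84)
dones = torch.FloatTensor([0.0, 1.0])

loss = agent.compute_loss((states, actions, rewards, next_states, dones))
agent.optimizer.zero_grad()
loss.backward()
agent.optimizer.step()
agent._soft_update_target_network()   # Polyak-average toward the online weights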

2.6 Staged Training Pipeline

A typical DQN training loop consists of several clearly separated stages, similar to the stages of a Swin Transformer.

def dqn_training_pipeline(env, agent, total_episodes=1000):
    """
    Staged DQN training pipeline.
    """
    replay_buffer = ReplayBuffer(capacity=10000)
    epsilon_start, epsilon_end, epsilon_decay = 1.0, 0.01, 0.995
    epsilon = epsilon_start

    for episode in range(total_episodes):
        state = env.reset()
        state_processor = StateProcessor()
        state = state_processor.process_state(state)
        total_reward = 0

        while True:
            # Stage 1: interaction and data collection
            action = agent.select_action(state, epsilon)
            next_observation, reward, done, _ = env.step(action)
            next_state = state_processor.process_state(next_observation)

            # Store the transition in the buffer
            replay_buffer.push(state.squeeze(0), action, reward, next_state.squeeze(0), done)
            state = next_state
            total_reward += reward

            # Stage 2: network update (once the buffer is large enough)
            if len(replay_buffer) > 64:  # analogous to running the later blocks of a Swin Transformer
                batch = replay_buffer.sample(batch_size=32)
                loss = agent.compute_loss(batch)

                agent.optimizer.zero_grad()
                loss.backward()
                # Optional: gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(agent.online_network.parameters(), max_norm=10.0)
                agent.optimizer.step()

                # Stage 3: soft-update the target network
                agent._soft_update_target_network()

            if done:
                break

        # Decay the exploration rate
        epsilon = max(epsilon_end, epsilon * epsilon_decay)
        print(f"Episode {episode+1}: total reward = {total_reward:.2f}, epsilon = {epsilon:.3f}")

    print("Training finished!")

2.7 Output & Application

A trained DQN model can be used for decision making; its output is a value estimate for each action.

class TrainedDQN:
    """
    A trained DQN model used for actual decision making.
    """
    def __init__(self, model_path, n_actions):
        checkpoint = torch.load(model_path)
        self.model = DQNAgent(state_channels=4, n_actions=n_actions)
        self.model.online_network.load_state_dict(checkpoint['online_network_state_dict'])
        self.model.online_network.eval()  # switch to evaluation mode

    def predict(self, state):
        """
        Given a state, output the Q-values of all actions.
        Args:
            state: preprocessed state tensor (1, 4, 84, 84)
        Returns:
            q_values: Q-values of all actions (1, n_actions)
            best_action: index of the best action
        """
        with torch.no_grad():
            q_values = self.model.online_network(state)
            best_action = q_values.argmax().item()
        return q_values.numpy(), best_action

    def visualize_decision(self, state):
        """
        Visualise what drives the network's decision for the current state (optional).
        Methods such as class activation mapping can show which parts of the image the network attends to.
        """
        # Sketch of an advanced feature; implementation omitted.
        pass
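
A hypothetical usage sketch: the checkpoint file name is a placeholder, and the zero tensor stands in for a properly preprocessed 4-frame stack.

# Illustrative only: load a checkpoint and query Q-values for one state.
policy = TrainedDQN(model_path="dqn_checkpoint.pth", n_actions=6)
dummy_state = torch.zeros(1, 4, 84, 84)
q_values, best_action = policy.predict(dummy_state)
print(q_values, best_action)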

Structural Comparison Summary

Swin Transformer (Vision) | DQN (Reinforcement Learning) | Analogous core purpose
Patch Partition | Frame Stacking | Converts raw high-dimensional input (images) into a sequential/structured form the model can process.
Window Partition | Experience Sampling | Splits global information / the experience pool into smaller local units (windows / mini-batches) that can be processed independently or in parallel.
W-MSA | Online Q-value Estimation | Aggregates and computes information within the local/current scope (self-attention inside a window / computing current Q-values).
SW-MSA | Target Q-value Calculation | Introduces interaction with, and constraints from, beyond the local scope or the future (cross-window connections / target Q-values based on future rewards), preventing the model from getting stuck in local optima or acting myopically.
Patch Merging | Feature Downsampling in Conv | Downsamples and aggregates multi-scale, multi-level semantic features, enlarging the receptive field.
Hierarchical Stages | Training Episodes & Updates | Processes information and learns in stages, progressively refining the model's representation or policy.

Through this structural decomposition, DQN's complete data-processing and learning flow, from raw pixel input to the final action output, is cleanly divided into modular steps, each with a clear mathematical and engineering purpose. This modular perspective helps build a deeper understanding of how deep reinforcement learning models are designed.
