Reinforcement Learning: Training DQN Agents to Play Games
2.1 State Input and Feature Sequence Construction
The input to DQN is typically a sequence of consecutive game frames. To capture temporal dynamics such as an object's velocity and direction, the original paper stacks the most recent 4 frames into a single state tensor.
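As a minimal sketch of this idea (the arrays below are random stand-ins for real preprocessed frames), stacking four 84x84 grayscale frames along a new leading axis yields the (4, 84, 84) state shape used throughout this section:

```python
import numpy as np
import torch

# Four preprocessed 84x84 grayscale frames (random stand-ins for real game frames)
frames = [np.random.rand(84, 84).astype(np.float32) for _ in range(4)]
state = np.stack(frames, axis=0)                      # shape: (4, 84, 84)
state_tensor = torch.from_numpy(state).unsqueeze(0)   # shape: (1, 4, 84, 84), ready for a conv net
print(state_tensor.shape)                             # torch.Size([1, 4, 84, 84])
```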
```python
"""
Deep Q-Network (DQN) agent implementation.
Contains the core components: network architecture, experience replay, and training logic.
"""
import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class DQN(nn.Module):
    """DQN network - convolutional version for Atari games."""

    def __init__(self, input_shape, n_actions):
        """
        Args:
            input_shape: shape of the input state (channels, height, width)
            n_actions: size of the action space
        """
        super(DQN, self).__init__()
        # Convolutional feature extraction layers
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        # Number of features produced by the convolutional layers
        conv_out_size = self._get_conv_out(input_shape)
        # Fully connected value-estimation layers
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def _get_conv_out(self, shape):
        """Compute the number of features output by the convolutional layers."""
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        """Forward pass."""
        conv_out = self.conv(x)
        conv_out_flat = conv_out.view(conv_out.size(0), -1)
        return self.fc(conv_out_flat)

class SimpleDQN(nn.Module):
    """Simplified DQN network - for environments with vector states such as CartPole."""

    def __init__(self, state_dim, n_actions):
        super(SimpleDQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )

    def forward(self, x):
        return self.net(x)

class ReplayBuffer:
    """Experience replay buffer."""

    def __init__(self, capacity=100000):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store a single transition."""
        experience = (state, action, reward, next_state, done)
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Randomly sample a batch of transitions."""
        if len(self.buffer) < batch_size:
            return None
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.stack(states),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.stack(next_states),
            torch.FloatTensor(dones)
        )

    def __len__(self):
        return len(self.buffer)

class DQNAgent:
    """DQN agent."""

    def __init__(self, state_dim, n_actions, use_image_input=False,
                 lr=1e-4, gamma=0.99, tau=0.005, device='cpu'):
        """
        Args:
            state_dim: state dimensionality (or image shape when use_image_input=True)
            n_actions: number of actions
            use_image_input: whether the state is an image
            lr: learning rate
            gamma: discount factor
            tau: soft-update coefficient for the target network
            device: compute device
        """
        self.device = device
        self.n_actions = n_actions
        self.gamma = gamma
        self.tau = tau
        self.use_image_input = use_image_input
        # Build the networks
        if use_image_input:
            self.online_net = DQN(state_dim, n_actions).to(device)
            self.target_net = DQN(state_dim, n_actions).to(device)
        else:
            self.online_net = SimpleDQN(state_dim, n_actions).to(device)
            self.target_net = SimpleDQN(state_dim, n_actions).to(device)
        # Synchronize the target network with the online network
        self.target_net.load_state_dict(self.online_net.state_dict())
        self.target_net.eval()  # target network stays in evaluation mode
        # Optimizer
        self.optimizer = optim.Adam(self.online_net.parameters(), lr=lr)
        # Experience replay buffer
        self.memory = ReplayBuffer(capacity=10000)
        # Training step counter
        self.train_step = 0

    def select_action(self, state, epsilon=0.1):
        """Select an action with an epsilon-greedy policy."""
        if random.random() < epsilon:
            # Explore: pick a random action
            return random.randint(0, self.n_actions - 1)
        else:
            # Exploit: pick the action with the highest Q-value
            with torch.no_grad():
                state = state.unsqueeze(0) if state.dim() == 3 else state
                state = state.to(self.device)
                q_values = self.online_net(state)
                return q_values.argmax().item()

    def store_experience(self, state, action, reward, next_state, done):
        """Store a transition in the replay buffer."""
        self.memory.push(state, action, reward, next_state, done)
    def learn(self, batch_size=64):
        """Learn from a sampled batch of experience."""
        # Sample a batch
        batch = self.memory.sample(batch_size)
        if batch is None:
            return 0
        states, actions, rewards, next_states, dones = batch
        # Move tensors to the compute device
        states = states.to(self.device)
        actions = actions.to(self.device)
        rewards = rewards.to(self.device)
        next_states = next_states.to(self.device)
        dones = dones.to(self.device)
        # Current Q-values for the actions that were taken
        current_q_values = self.online_net(states).gather(1, actions.unsqueeze(1))
        # Target Q-values
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
            target_q_values = rewards + (self.gamma * next_q_values * (1 - dones))
        # Loss
        loss = F.smooth_l1_loss(current_q_values, target_q_values.unsqueeze(1))
        # Optimization step
        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), 10.0)
        self.optimizer.step()
        # Soft-update the target network
        self.soft_update_target_network()
        self.train_step += 1
        return loss.item()

    def soft_update_target_network(self):
        """Soft-update the target network parameters (Polyak averaging)."""
        for target_param, online_param in zip(self.target_net.parameters(),
                                              self.online_net.parameters()):
            target_param.data.copy_(
                self.tau * online_param.data + (1.0 - self.tau) * target_param.data
            )

    def hard_update_target_network(self):
        """Hard-update: copy the online network parameters into the target network."""
        self.target_net.load_state_dict(self.online_net.state_dict())

    def save(self, path):
        """Save the model."""
        torch.save({
            'online_net_state_dict': self.online_net.state_dict(),
            'target_net_state_dict': self.target_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_step': self.train_step,
        }, path)

    def load(self, path):
        """Load the model."""
        checkpoint = torch.load(path, map_location=self.device)
        self.online_net.load_state_dict(checkpoint['online_net_state_dict'])
        self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.train_step = checkpoint['train_step']

class StateProcessor:
    """State processor for image observations."""

    def __init__(self, frame_stack=4, img_size=(84, 84)):
        self.frame_stack = frame_stack
        self.img_size = img_size
        self.frame_buffer = deque(maxlen=frame_stack)

    def reset(self):
        """Clear the frame buffer."""
        self.frame_buffer.clear()

    def process_state(self, observation):
        """
        Convert a raw observation into the network input format.

        Args:
            observation: raw observation (RGB or grayscale frame)
        Returns:
            state_tensor: processed state tensor of shape (1, frame_stack, H, W)
        """
        # Convert to grayscale
        if len(observation.shape) == 3:
            # RGB image -> grayscale
            gray = np.dot(observation[..., :3], [0.2989, 0.5870, 0.1140])
        else:
            gray = observation
        # Resize (requires scikit-image)
        from skimage.transform import resize
        resized = resize(gray, self.img_size, mode='reflect', anti_aliasing=True)
        # Normalize to [0, 1]
        resized = resized.astype(np.float32) / 255.0
        # Append to the frame buffer
        self.frame_buffer.append(resized)
        # If the buffer is not yet full, pad it with the first frame
        while len(self.frame_buffer) < self.frame_stack:
            self.frame_buffer.appendleft(resized)
        # Stack frames along the channel axis
        state = np.stack(self.frame_buffer, axis=0)
        # Convert to a tensor with a batch dimension
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        return state_tensor

class VectorStateProcessor:
    """State processor for vector observations (e.g. CartPole)."""

    def __init__(self):
        pass

    def reset(self):
        """Nothing to reset for vector observations."""
        pass

    def process_state(self, observation):
        """
        Convert a raw vector observation into a tensor.

        Args:
            observation: raw observation vector
        Returns:
            state_tensor: processed state tensor of shape (1, state_dim)
        """
        # Convert directly to a tensor with a batch dimension
        state_tensor = torch.FloatTensor(observation).unsqueeze(0)
        return state_tensor
```
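The following is a minimal usage sketch (not part of the original listing) showing how the classes above fit together on CartPole. It assumes the gymnasium API, where reset() returns (obs, info) and step() returns (obs, reward, terminated, truncated, info):

```python
import gymnasium as gym

env = gym.make("CartPole-v1")
processor = VectorStateProcessor()
agent = DQNAgent(state_dim=4, n_actions=2, use_image_input=False)

epsilon = 1.0
for episode in range(200):
    obs, _ = env.reset()
    state = processor.process_state(obs).squeeze(0)    # shape: (4,)
    episode_reward = 0.0
    done = False
    while not done:
        action = agent.select_action(state.unsqueeze(0), epsilon)
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        next_state = processor.process_state(obs).squeeze(0)
        agent.store_experience(state, action, reward, next_state, float(done))
        agent.learn(batch_size=64)                     # no-op until the buffer has enough samples
        state = next_state
        episode_reward += reward
    epsilon = max(0.05, epsilon * 0.99)                # simple per-episode exploration decay
    print(f"episode {episode}: reward = {episode_reward:.1f}, epsilon = {epsilon:.2f}")
```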
2.2 Feature Extraction Backbone
This is the core of DQN: a stack of convolutional layers that extracts high-level features from raw pixels.
```python
import numpy as np
import torch
import torch.nn as nn


class DQNBackbone(nn.Module):
    """DQN feature extraction backbone (Atari version)."""

    def __init__(self, input_channels=4):
        super(DQNBackbone, self).__init__()
        self.conv_layers = nn.Sequential(
            # Conv layer 1: 4 input channels -> 32 output channels
            nn.Conv2d(in_channels=input_channels, out_channels=32, kernel_size=8, stride=4),
            # kernel_size=8, stride=4 -> output size: (84 - 8) / 4 + 1 = 20
            nn.ReLU(inplace=True),
            # Conv layer 2: 32 input channels -> 64 output channels
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            # kernel_size=4, stride=2 -> output size: (20 - 4) / 2 + 1 = 9
            nn.ReLU(inplace=True),
            # Conv layer 3: 64 input channels -> 64 output channels
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            # kernel_size=3, stride=1 -> output size: (9 - 3) / 1 + 1 = 7
            nn.ReLU(inplace=True)
        )
        # Number of features produced by the convolutional layers
        self._feature_dim = self._calculate_conv_output(input_channels)

    def _calculate_conv_output(self, input_channels):
        """Run a zero tensor through the conv stack to determine the output size."""
        with torch.no_grad():
            dummy_input = torch.zeros(1, input_channels, 84, 84)
            dummy_output = self.conv_layers(dummy_input)
            return int(np.prod(dummy_output.size()))

    def forward(self, x):
        """
        Args:
            x: input tensor of shape (batch_size, 4, 84, 84)
        Returns:
            features: flattened feature vector of shape (batch_size, 3136)
        """
        conv_out = self.conv_layers(x)  # shape: (batch_size, 64, 7, 7)
        print(f"Conv output shape: {conv_out.shape}")
        features = conv_out.view(conv_out.size(0), -1)  # flatten -> (batch_size, 64*7*7 = 3136)
        print(f"Flattened feature shape: {features.shape}")
        return features
```
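A quick shape check (a sketch, not part of the original article) confirms the flattened feature dimension of 64 x 7 x 7 = 3136 for an 84x84 input:

```python
backbone = DQNBackbone(input_channels=4)
dummy_states = torch.zeros(2, 4, 84, 84)   # a batch of two stacked-frame states
features = backbone(dummy_states)          # the forward pass prints the intermediate shapes
print(features.shape)                      # torch.Size([2, 3136])
print(backbone._feature_dim)               # 3136
```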
2.3 Value Estimation Head
Maps the extracted features to a Q-value for each possible action.
```python
class DQNValueHead(nn.Module):
    """
    DQN value estimation head.

    Input:  a feature vector of shape (batch_size, 3136)
    Output: Q-values of shape (batch_size, n_actions), one per action
    """

    def __init__(self, feature_dim, n_actions):
        super(DQNValueHead, self).__init__()
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features=feature_dim, out_features=512),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=n_actions)  # raw Q-value for each action
        )

    def forward(self, features):
        """
        Args:
            features: features from the backbone, shape (batch_size, feature_dim)
        Returns:
            q_values: Q-value for each action, shape (batch_size, n_actions)
        """
        q_values = self.fc_layers(features)
        print(f"Q-value output shape: {q_values.shape}")
        return q_values
```
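The backbone and head chain naturally into a complete Q-network. The sketch below (using a hypothetical 6-action game) shows the same composition the agent in section 2.5 builds with nn.Sequential:

```python
backbone = DQNBackbone(input_channels=4)
head = DQNValueHead(feature_dim=backbone._feature_dim, n_actions=6)
q_network = nn.Sequential(backbone, head)

dummy_states = torch.zeros(2, 4, 84, 84)
q_values = q_network(dummy_states)       # shape: (2, 6), one Q-value per action per state
greedy_actions = q_values.argmax(dim=1)  # greedy action index for each state in the batch
```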
2.4 Experience Replay & Batch Sampling
This is the key mechanism for stable DQN training. Analogous to window partitioning in Swin Transformer, it "partitions" a large pool of interaction experience into mini-batches used for learning.
```python
class ReplayBuffer:
    """
    Experience replay buffer.

    Stores the agent's transitions (s, a, r, s', done) and samples random
    mini-batches to break the correlation between consecutive experiences.
    """

    def __init__(self, capacity=100000):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        """Store a single transition (overwriting the oldest one when full)."""
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """
        Sample a random mini-batch of transitions.

        Args:
            batch_size: number of transitions to sample
        Returns:
            Batched (states, actions, rewards, next_states, dones)
        """
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, next_states, dones = zip(*[self.buffer[idx] for idx in indices])
        # Convert to PyTorch tensors
        return (torch.stack(states),
                torch.LongTensor(actions),
                torch.FloatTensor(rewards),
                torch.stack(next_states),
                torch.FloatTensor(dones))

    def __len__(self):
        return len(self.buffer)
```
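A short push/sample sketch with toy 4-dimensional vector states (chosen purely for illustration; the Atari agent stores (4, 84, 84) frame stacks instead):

```python
import random

buffer = ReplayBuffer(capacity=1000)
for _ in range(100):
    s, s_next = torch.randn(4), torch.randn(4)       # toy transitions
    buffer.push(s, random.randint(0, 1), 1.0, s_next, 0.0)

states, actions, rewards, next_states, dones = buffer.sample(batch_size=32)
print(states.shape, actions.shape, dones.shape)      # torch.Size([32, 4]) torch.Size([32]) torch.Size([32])
```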
2.5 Target Network Mechanism
Analogous to the cross-window information exchange enabled by the shifted windows in Swin Transformer, the target network provides relatively stable target Q-values and thereby addresses the training instability caused by bootstrapping.
```python
class DQNAgent:
    """DQN agent combining the online network, target network, and learning logic."""

    def __init__(self, state_channels, n_actions, learning_rate=1e-4, gamma=0.99, tau=0.005):
        self.gamma = gamma  # discount factor for future rewards
        self.tau = tau      # soft-update coefficient for the target network
        # Online network: selects actions and computes the current Q-values
        self.online_backbone = DQNBackbone(input_channels=state_channels)
        self.online_head = DQNValueHead(feature_dim=self.online_backbone._feature_dim, n_actions=n_actions)
        self.online_network = nn.Sequential(self.online_backbone, self.online_head)
        self.optimizer = torch.optim.Adam(self.online_network.parameters(), lr=learning_rate)
        # Target network: computes target Q-values; its parameters are periodically synced from the online network
        self.target_backbone = DQNBackbone(input_channels=state_channels)
        self.target_head = DQNValueHead(feature_dim=self.target_backbone._feature_dim, n_actions=n_actions)
        self.target_network = nn.Sequential(self.target_backbone, self.target_head)
        # At initialization the target network is identical to the online network
        self._hard_update_target_network()

    def _hard_update_target_network(self):
        """Hard update: copy the online network parameters into the target network."""
        self.target_network.load_state_dict(self.online_network.state_dict())
        print("Target network hard-updated.")

    def _soft_update_target_network(self):
        """Soft update: slowly blend the online parameters into the target network (Polyak averaging)."""
        for target_param, online_param in zip(self.target_network.parameters(), self.online_network.parameters()):
            target_param.data.copy_(self.tau * online_param.data + (1.0 - self.tau) * target_param.data)

    def compute_loss(self, batch):
        """
        Compute the loss (Huber / smooth L1) for one batch.

        Args:
            batch: a tuple (s, a, r, s', done) sampled from the ReplayBuffer
        Returns:
            loss: scalar loss
        """
        states, actions, rewards, next_states, dones = batch
        # 1. Current Q-values for the taken actions, from the online network
        current_q_values = self.online_network(states).gather(1, actions.unsqueeze(1))  # shape: (batch, 1)
        print(f"Current Q-value shape: {current_q_values.shape}")
        # 2. Maximum next-state Q-values from the target network
        #    (Double DQN would instead use the online network to select the action)
        with torch.no_grad():  # target computation must not produce gradients
            next_q_values = self.target_network(next_states).max(1)[0]  # shape: (batch,)
            target_q_values = rewards + (self.gamma * next_q_values * (1 - dones))
        print(f"Target Q-value shape: {target_q_values.unsqueeze(1).shape}")
        # 3. Huber loss (more robust to outliers than MSE)
        loss = nn.functional.smooth_l1_loss(current_q_values, target_q_values.unsqueeze(1))
        print(f"Computed loss: {loss.item():.4f}")
        return loss

    def select_action(self, state, epsilon):
        """
        Select an action with an epsilon-greedy policy.

        Args:
            state: current state tensor
            epsilon: exploration probability
        Returns:
            action: index of the chosen action
        """
        if np.random.rand() < epsilon:
            # Explore: pick a random action
            return np.random.randint(self.online_head.fc_layers[-1].out_features)
        else:
            # Exploit: pick the action with the highest estimated Q-value
            with torch.no_grad():
                q_values = self.online_network(state)
                print(f"Q-value shape during action selection: {q_values.shape}")
                return q_values.argmax().item()
```
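To make the target computation in compute_loss concrete, here is a tiny worked example with made-up numbers; note how the terminal transition (done = 1) keeps only its immediate reward:

```python
gamma = 0.99
rewards = torch.tensor([1.0, 0.0, 1.0])
dones = torch.tensor([0.0, 0.0, 1.0])                        # the third transition is terminal
next_q = torch.tensor([[0.5, 2.0], [1.5, 1.0], [3.0, 4.0]])  # Q_target(s', a') for two actions

max_next_q = next_q.max(1)[0]                                # tensor([2.0000, 1.5000, 4.0000])
targets = rewards + gamma * max_next_q * (1 - dones)
print(targets)                                               # tensor([2.9800, 1.4850, 1.0000])
```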
2.6 Staged Training Pipeline
A typical DQN training loop consists of several clearly separated stages, much like the stages of a Swin Transformer.
```python
def dqn_training_pipeline(env, agent, total_episodes=1000):
    """
    Staged DQN training pipeline.

    Note: the loop assumes the classic gym API, i.e. env.reset() returns an
    observation and env.step() returns (obs, reward, done, info).
    """
    replay_buffer = ReplayBuffer(capacity=10000)
    epsilon_start, epsilon_end, epsilon_decay = 1.0, 0.01, 0.995
    epsilon = epsilon_start

    for episode in range(total_episodes):
        state = env.reset()
        state_processor = StateProcessor()  # fresh frame buffer for each episode
        state = state_processor.process_state(state)
        total_reward = 0
        while True:
            # Stage 1: interaction and data collection
            action = agent.select_action(state, epsilon)
            next_observation, reward, done, _ = env.step(action)
            next_state = state_processor.process_state(next_observation)
            # Store the transition in the replay buffer
            replay_buffer.push(state.squeeze(0), action, reward, next_state.squeeze(0), done)
            state = next_state
            total_reward += reward
            # Stage 2: network update (once the buffer is large enough)
            if len(replay_buffer) > 64:  # analogous to running the later blocks in Swin Transformer
                batch = replay_buffer.sample(batch_size=32)
                loss = agent.compute_loss(batch)
                agent.optimizer.zero_grad()
                loss.backward()
                # Optional: gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(agent.online_network.parameters(), max_norm=10.0)
                agent.optimizer.step()
                # Stage 3: soft-update the target network
                agent._soft_update_target_network()
            if done:
                break
        # Decay the exploration rate
        epsilon = max(epsilon_end, epsilon * epsilon_decay)
        print(f"Episode {episode + 1}: total reward = {total_reward:.2f}, epsilon = {epsilon:.3f}")

    print("Training complete!")
```
2.7 Output & Application
A trained DQN model can be used for decision making; its output is a value estimate for every action.
```python
class TrainedDQN:
    """
    A trained DQN model used for actual decision making.
    """

    def __init__(self, model_path, n_actions):
        checkpoint = torch.load(model_path)
        self.model = DQNAgent(state_channels=4, n_actions=n_actions)
        self.model.online_network.load_state_dict(checkpoint['online_network_state_dict'])
        self.model.online_network.eval()  # evaluation mode

    def predict(self, state):
        """
        Given a state, output the Q-values for all actions.

        Args:
            state: preprocessed state tensor of shape (1, 4, 84, 84)
        Returns:
            q_values: Q-values for all actions, shape (1, n_actions)
            best_action: index of the best action
        """
        with torch.no_grad():
            q_values = self.model.online_network(state)
            best_action = q_values.argmax().item()
        return q_values.numpy(), best_action

    def visualize_decision(self, state):
        """
        Visualize what the network bases its decision on (optional).
        Techniques such as class activation mapping can reveal which parts of
        the image the network attends to.
        """
        # Sketch of an advanced feature; implementation omitted.
        pass
```
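A short inference sketch (the checkpoint path and action count below are placeholders, and the checkpoint key must match however your agent was actually saved):

```python
processor = StateProcessor(frame_stack=4, img_size=(84, 84))
policy = TrainedDQN(model_path="dqn_breakout.pth", n_actions=4)  # hypothetical checkpoint file

frame = np.random.randint(0, 255, size=(210, 160, 3), dtype=np.uint8)  # stand-in for a raw game frame
state = processor.process_state(frame)          # shape: (1, 4, 84, 84)
q_values, best_action = policy.predict(state)
print(f"best action: {best_action}, Q-values: {q_values}")
```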
Structural Comparison Summary
| Swin Transformer (Vision) | DQN (Reinforcement Learning) | Shared Purpose |
|---|---|---|
| Patch Partition | Frame Stacking | Convert the raw high-dimensional input (images) into a serialized/structured form the model can process. |
| Window Partition | Experience Sampling | Split the global information / experience pool into smaller units (windows / mini-batches) that can be processed independently or in parallel. |
| W-MSA | Online Q-value Estimation | Aggregate and compute information within a local/current scope (self-attention within a window / computing the current Q-values). |
| SW-MSA | Target Q-value Calculation | Introduce cross-local, forward-looking interaction and constraints (cross-window connections / target Q-values based on future rewards) to keep the model from becoming myopic or stuck in local optima. |
| Patch Merging | Feature Downsampling in Conv | Downsample to aggregate multi-scale, higher-level semantic features and enlarge the receptive field. |
| Hierarchical Stages | Training Episodes & Updates | Process information and learn in stages, progressively refining the representation or policy. |
With this structural decomposition, the complete data-processing and learning pipeline of DQN, from raw pixel input to the final action output, is divided into clearly modular steps, each with a well-defined mathematical and engineering purpose. This modular perspective helps build a deeper understanding of how deep reinforcement learning models are designed.
