SB3库中向量化环境vecenv是单环境时 action被错误调整维度导致报错/擅自取出obs导致观测值不是向量化
代码简介
`
action维度维(3,)
self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)
单向量化环境
env = Env(render_mode="human" if render else None)
env = DummyVecEnv([lambda: env])
env = VecNormalize.load(env_stats_path, env)
predict 得到的 action = [-0.9983589 0.9999726 0.9999852] ,shape:(3,)
action, _ = model.predict(obs, deterministic=deterministic)
执行过程
obs, reward, done, info = env.step(action)
进入 D:/ProgramData/anaconda3/envs/sb3/Lib/site-packages/stable_baselines3/common/vec_env/vec_normalize.py:181 obs, rewards, dones, infos = self.venv.step_wait()
进入 D:/ProgramData/anaconda3/envs/sb3/Lib/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py:55
def step_wait(self) -> VecEnvStepReturn:
# Avoid circular imports
for env_idx in range(self.num_envs):
obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step(
self.actions[env_idx]
)
出错:self.actions[env_idx]是一个float,不是数组
进入 自定义的step函数,得到的action不再是shape=(3,)
`
正常情况
在 Stable-Baselines3 的 VecEnv 实现中,当您将一个 action 传递给 env.step() 时,如果这是一个单环境(DummyVecEnv 包装了单个环境),action 会被自动扩展一个批次维度。
也就是说,如果您传入的 action 是 (3,),它在 VecEnv 内部可能会被处理成 (1, 3)。当 DummyVecEnv 内部迭代并尝试取出 self.actions[env_idx] 时,它取出的仍然是 (3,) 维度的 action。
实际情况
action在 VecEnv 内部确实被处理成 (1, 3)了,但是由于 vectorized_env=False,action维度又变成(3,)了
D:/ProgramData/anaconda3/envs/sb3/Lib/site-packages/stable_baselines3/common/policies.py:382
if not vectorized_env:
assert isinstance(actions, np.ndarray)
actions = actions.squeeze(axis=0)
问题定位
D:/ProgramData/anaconda3/envs/sb3/Lib/site-packages/stable_baselines3/common/policies.py:365
obs_tensor, vectorized_env = self.obs_to_tensor(observation)
vectorized_env一直为True,直到被这行代码赋值为False
解决办法:确保使用向量化环境时;vectorized_env=True
obs_tensor, vectorized_env = self.obs_to_tensor(observation) 把vectorized_env 赋值为False,
是因为observation不是向量化的,回到自己的代码发现:obs = env.reset()[0]
修改为obs = env.reset()即可
结论
对 Stable-Baselines3 中 VecEnv(向量化环境)和 Policy(策略)之间的数据交互约定理解不足。
在使用复杂框架时,务必深入理解其核心数据结构(例如张量/数组的形状和维度)、模块之间的输入输出接口规范。
浙公网安备 33010602011771号