【学习】ACT项目的复现与理解

一、 环境配置

conda create -n aloha python=3.10.12
conda activate aloha
pip instalohal torchvision
pip install typeguard
pip install torch
pip install pyquaternion
pip install pyyaml
pip install rospkg
pip install pexpect

# ===== 如果需要升级mujoco需要搭配升级dm_control ===== 
pip install mujoco==2.3.7
pip install dm_control==1.0.14

或

pip install mujoco==3.2.5
pip install dm_control==1.0.25
# ===== 如果需要升级mujoco需要搭配升级dm_control ===== 

pip install opencv-python
pip install matplotlib
pip install einops
pip install packaging
pip install h5py
pip install ipython
cd act/detr && pip install -e .

二、 运行

我认为绝大多数人都没有aloha吧,并且很多人能操作的是一组机械臂,无法两组机械臂做采集和示范,因此后面的内容主要倾向于单组机械臂,使用仿真数据进行训练。

2.1 生成训练用轨迹

命令:

python3 record_sim_episodes.py \
--task_name sim_transfer_cube_scripted \
--dataset_dir ./data_sim_episodes/sim_transfer_cube_scripted \
--num_episodes 50

参数说明:

  • --task_name
    • 任务名称,act项目中预设了两类四种任务
      • 脚本生成
        • sim_transfer_cube_scripted,传递方块
        • sim_insertion_scripted,双臂插入
      • 人类示教
        • sim_transfer_cube_human,传递方块
        • sim_insertion_human,双臂插入
  • --dataset_dir
  • --num_episodes
    • 要生成的演示数据数量(episode数量)
    • 1个episode指的是一个任务从初始状态到结束状态
    • 训练时项目自动分割数据,8:2
  • --onscreen_render
    • 是否在屏幕上实时显示仿真画面
    • 默认false,需要显示指定
    • 如果远程了服务器,无法开启这个参数

1

2.2 可视化episode数据

多数情况下,大家会通过无桌面的服务器跑项目,那就需要这个命令观察自己做的仿真数据是否符合预期。

命令:

python3 visualize_episodes.py \
    --dataset_dir ./data_sim_episodes/sim_transfer_cube_scripted \
    --episode_idx 0 

参数说明

  • dataset_dir
    • episode存放路径
  • episode_idx
    • 待可视化的episode序号

2

2.3 训练策略

修改constants.py,配置episode读取路径

示例:/home/zhaoshuai/workspace_act/act/data_sim_episodes

这里写到父目录即可,下一级目录在SIM_TASK_CONFIGS中dataset_dir中定义训练

3

4

训练

python3 imitate_episodes.py \
    --task_name sim_transfer_cube_scripted \
    --ckpt_dir ./ckpts/sim_transfer_cube \
    --policy_class ACT \
    --kl_weight 10 \
    --chunk_size 100 \
    --hidden_dim 512 \
    --batch_size 8 \
    --dim_feedforward 3200 \
    --num_epochs 2000 \
    --lr 1e-5 \
    --seed 0


python3 /home/zhaoshuai/workspace_act/act/imitate_episodes.py --task_name sim_lifting_cube_scripted --ckpt_dir ./ckpts/sim_lifting_cube_scripted --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 --num_epochs 2000 --lr 1e-5 --seed 0 --equipment_model fairino5_single

参数说明

  • --eval
    • 是否开启评估模式,默认false
  • --ckpt_dir
    • 训练时模型保存路径
  • --policy_class
    • 主要策略,这里项目中虽然带此参数,但实际上没有写对应逻辑,只能选ACT和CNNMLP(还不完善)
  • --onscreen_render
    • 是否在屏幕上实时显示仿真画面
    • 默认false
  • --task_name
    • 任务类型,同2.1
  • --batch_size
    • 单次传递样本个数
  • --kl_weight
    • KL散度权重
    • 控制模型发散能力,kl_weight越小,模型发散能力越强,动作越多样
    • 控制模型发散能力,kl_weight越大,模型发散能力越弱,动作越固定
    • CVAE条件变分自编码器
      • total_loss = L1_loss + kl_weight × KL_loss
      • L1_loss:预测动作和真实动作的差异
      • KL_loss:学到的分布和标准正态分布的差异
  • --chunk_size
    • 每次预测的未来步数
    • 数值越大越适合长期任务,显存占用越大,推理速度越慢
  • --hidden_dim
    • transformer中d_model、Query、Key、Value、特征向量的维度
    • 一般情况下,transformer的参数量=4hidden_dimhidden_dim
    • hidden_dim越大(2的n次方取值),显存占用翻倍提升,训练速度快速下降,性能小幅提升
  • --dim_feedforward
    • transformer中每层前馈网络 (Feed-Forward Network, FFN)维度
    • hidden_dim可以类比成长期记忆,记录最有用的
    • dim_feedforward类比成短期草稿,实时进行计算,越大展开越详细速度越慢越占用显存
  • --num_epochs
    • 训练轮次,所有训练数据都过一遍算1轮
  • --lr
    • 学习率,在梯度下降过程中更新权重时的超参数
    • 一般根据经验设置在0.01到0.001之间
    • 学习率越小损失函数变化速度就越慢,越容易过拟合
    • 学习率越大损失函数振动幅度就越大,模型难以收敛,容易发生梯度爆炸
  • --seed
    • 随机种子

5

2.4 评估策略

在训练命令的基础上添加--eval即可。

python3 imitate_episodes.py \
    --task_name sim_transfer_cube_scripted \
    --ckpt_dir ./ckpts/sim_transfer_cube \
    --policy_class ACT \
    --kl_weight 10 \
    --chunk_size 100 \
    --hidden_dim 512 \
    --batch_size 8 \
    --dim_feedforward 3200 \
    --num_epochs 2000 \
    --lr 1e-5 \
    --seed 0 \
    --eval

这里程序会

  • 加载训练好的最佳模型(policy_best. ckpt)
  • 在仿真环境中运行 50 次独立测试
  • 记录每次的成功/失败
  • 计算整体成功率
  • 为每次测试生成视频

6

7

三、对项目的理解

3.1 轨迹生成部分

由于使用仿真生成的数据进行训练,那么在自定义任务时,势必需要理解并改动此处,因此记录一下自己的心得。

对于整个项目来说,其所需的训练数据(等同于轨迹生成数据)的数据格式是hdf5.

在make_ee_sim_env中阐述了其数据组成结构(以双臂为例)

8

整体包含观测数据 Observation 和action数据

  • Observation
    • qpos
      • (T, 14) float
      • 左臂关节6 + 左臂夹爪1 + 右臂关节6 + 右臂夹爪1
      • 关节关于初始角度的绝对角度
    • qvel
      • (T, 14) float
      • 左臂关节6 + 左臂夹爪1 + 右臂关节6 + 右臂夹爪1
      • 在当前timestep,关节的瞬时速度
    • images/
      • cam0
      • (T, 480, 640, 3) uint8
      • 3路 rgb 480*640图像
  • action
    • 左臂关节6 + 左臂夹爪1 + 右臂关节6 + 右臂夹爪1
    • 关节关于初始角度的绝对角度

调用关系

record_sim_episodes.py (主程序)
    │
    ├─→ scripted_policy.py
    │   ├─ PickAndTransferPolicy.__call__()
    │   │   └─ 返回:action (16维末端坐标)
    │   │       [left_x, left_y, left_z, left_qw, left_qx, left_qy, left_qz, left_gripper,
    │   │        right_x, right_y, right_z, right_qw, right_qx, right_qy, right_qz, right_gripper]
    │   │
    │   └─ generate_trajectory()
    │       └─ 定义路点轨迹
    │
    ├─→ ee_sim_env.py
    │   ├─ make_ee_sim_env()
    │   │   └─ 返回:Environment 对象
    │   │
    │   └─ BimanualViperXEETask
    │       ├─ before_step(action, physics)
    │       │   ├─ 输入:action (16维末端坐标)
    │       │   ├─ 处理:设置 mocap_pos 和 mocap_quat
    │       │   └─ 效果:mocap 拖动机械臂末端
    │       │
    │       └─ get_observation(physics)
    │           └─ 返回:包含 qpos (14维关节角度) 的观测
    │
    └─→ dm_control (MuJoCo 物理引擎)
        ├─ control.Environment
        │   └─ step(action)
        │       ├─ 调用:task.before_step(action, physics)
        │       ├─ 执行:physics.step() 物理仿真
        │       └─ 调用:task.get_observation(physics)
        │
        └─ mujoco.Physics
            ├─ data.mocap_pos  ← before_step 写入末端坐标
            ├─ data.mocap_quat ← before_step 写入末端姿态
            ├─ step()          ← 物理仿真推进
            │   └─ 内部:mocap 约束 → 求解逆运动学 → 更新关节角度
            └─ data.qpos       ← step 后更新的关节角度

大致流程

  • scripted_policy输出mocap的[x, y, z, quat, gripper];
  • ee_sim_env 接收后,转换为关节控制并推进仿真、记录qpos;
  • 将ee_sim_env中记录的qpos作为target输入给sim_env;
  • sim_env记录qpos、qvel;
  • 最终将ee_sim_env记录的qpos作为action,结合sim_env的qpos、qvel、images存入hdf5。
  • 记录hdf5
    • /observations/qpos :sim_env 的 ts.observation['qpos']
    • /observations/qvel :sim_env 的 ts.observation['qvel']
    • /observations/images/<cam> :sim_env 的 ts.observation['images'][cam]
    • /action :ee_sim_env 的 ts.observation['qpos']

其中scripted_policy生成的action跟hdf5中的/action不一样,其内容为:

  • 执行器末端 xyz,3
  • 执行器末端四元数,4,
  • 夹爪开合状态gripper,1
  • 右侧同理

整体调用细节如下:

  • make_ee_sim_env,ee_sim_env.py
    • 初始化mujoco、dm_control等环境
  • policy_cls(inject_noise),scripted_policy.py,BasePolicy,init
    • 初始化policy
  • policy(ts),scripted_policy.py,BasePolicy,call
    • 初次调用时,一次性生成完整的400个timestep,仿真阶段是开环执行,轨迹是一次性生成的,不再实时调整
  • self.generate_trajectory(ts),对应任务下重写的generate_trajectory,构建左右双手的关键帧(关键路点)

9

  • 遇到关键路点时做记录
    • curr_left_waypoint记录当前所在关键路点
    • next_left_waypoint记录下一个关键路点

10

  • class BasePolicy,interpolate,插值,得到当前时刻t的point
    • 根据t在curr_t和next_t的范围,计算point,具体如上图
    • 最终拼成action,包含28个数据,2(xyz+quat+gripper),做成一个list
  • ts = env.step(action)
    • 做成dm_env.TimeStep
    • DeepMind Control Suite 的标准时间步结构
    • 有了所有ts,mujoco就可以重现仿真场景
  • env.step,before_step
    • 将action转成mjdata
    • mujoco的physics 是 mujoco.Physics 类的实例,是 MuJoCo 物理引擎的 Python 接口,包含了整个仿真世界的所有状态和配置。
    • 将action数据复制到physics.data.mocap_pos和physics.data.mocap_quat中
    • physics.data.mocap_pos[0]表示left,[1]表示right,根据xml中worldbody下 的body顺序定义。

11

  • physics.data.ctrl存储两只手四个夹爪开合距离(左夹爪左侧关节,左夹爪右侧关节,右夹爪左侧关节,右夹爪右侧关节)

12

  • 计算每个timestep的reward
    • 对于transfer_cube任务,其reward规则如下(ee_sim_env.py,get_reward方法)

13

  • 所有ts都存入episode
  • 转换成mujcoco qpos
  • 回放

mujoco坐标系

14

15

3.2 assets部分

  • bimanual_viperx_ee_transfer_cube.xml
    • 场景描述
    • transfer_cube任务的场景中有哪些内容
    • 如两个机械臂、桌子、被抓的箱子
  • scene.xml
    • 加载桌子、相机、光源
  • vx300s_dependencies.xml
    • 资源配置文件
    • 加载机械臂的各个stl并给出mesh name,方便vx300s_left.xml和right定义机械臂
  • bimanual_viperx_transfer_cube.xml
    • transfer_cube场景回放时的任务场景,初始位置与bimanual_viperx_ee_transfer_cube完全对齐

3.3 制作轨迹

16

轨迹是相对的,而不是绝对的。

这里通过赋值能够看出来,170步之前的xyz基于box做偏移,目的是接近物体并抓起来;

170以后是为了接近左侧机械臂,因此按照meet_xyz做偏移。

如果新增抓取放入左侧,只需要一个box_xyz,一个target_area_xyz

posted @ 2026-03-05 13:06  小拳头呀  阅读(0)  评论(0)    收藏  举报