cosmos policy code walk

cosmos policy code walk
简介
对于不知道什么是 World Action Model(WAM) 的读者,建议先阅读一下 这篇博客,这篇博客非常清晰地介绍了什么是 WAM,和 VLA 的区别等。
这篇文章分析一下 cosmos policy 的代码实现,参考 官方代码仓库,当前 commit 版本是 18a2accadf4e7a3531e56754102af5a24d2316da。
部署
仓库提供了 Dockerfile,可以直接部署,我在 HKUST 的 SuperPod 上使用,需要改写为 Apptainer,下面放出我改写后的文件。
构建好镜像后正常运行,可以按照指导说的那样进入交互式模式简单体验一下,这里不再赘述。
代码结构
cosmos-policy/
├── README.md / SETUP.md / LIBERO.md / ROBOCASA.md / ALOHA.md # 文档入口
├── docker/ # 可忽略
├── bin/ # 可忽略
└── cosmos_policy/ ★ 主战场
├── models/ ★★★ 核心模型
├── modules/ ★★ SDE + Sampler
├── datasets/ ★★ 数据管线
├── experiments/robot/ ★★ 推理 & 评测
├── config/ ★ 实验配置
├── scripts/train.py ★ 训练入口
├── trainer.py ★ 训练循环扩展
├── conditioner.py ★ 条件对象
├── constants.py ★ 平台常量
├── tokenizers/wan2pt1.py ☆ policy tokenizer 封装
├── utils/ ☆ 工具
└── _src/ ✗ 三个轮子,整体跳过
├── imaginaire/ (~200 files) 训练框架
├── predict2/ (~200 files) 视频生成模型
└── reason1/ (~50 files) VLM 嵌入(Policy 主要用 T5)
所有的代码都放在 cosmos-policy 目录下,我们后面讲的时候默认 Workspace 在这里。
_src 目录下存放了这里用到的三个轮子,分别是 nvidia 自己造的 imaginaire、predict2 和 reason1,这三个库的用法这里不讲。
models 目录下两个文件非常重要,定义了模型的结构,稍后我们重点讲。
modules 下两个文件分别定义了噪声采样器 HybridEDMSDE 和去噪器 CosmosPolicySampler。
experiments/robot 下存放实验相关代码,cosmos_utils.py 文件中实现了 get_acion() 函数,非常重要。
数据部分
重点看一下 datasets/libero_datasets.py 这个文件。
加载数据
从 171行到234行这一段是在加载演示数据,将每个 episode 存入内存,格式如下:
self.data[42] = {
"images": np.ndarray, # (150/T, 128/H, 128/W, 3) uint8
"wrist_images": np.ndarray, # (150/T, 128/H, 128/W, 3) uint8
"proprio": np.ndarray, # (150/T, 9) float32,z-score 归一化
"actions": np.ndarray, # (150/T, 7) float32,z-score 归一化
"command": "put both the alphabet soup and the tomato sauce in the basket",
"num_steps": 150,
"suite": "libero_10_no_noops_rerendered",
"returns": np.ndarray, # (150/T,) float32,demo 末尾 reward=1.0 折扣得到
}
指令从文件名中解析,之后计算价值,termial reward = 1.0,用传入的常数 gamma 做 Monte Carlo 折扣,之后缩放到 -1,1 之间。
之后跳到 401-414 行,这里构建了全局 step 索引,按 task suite 进行分组,进行均衡采样。
__getitem__ 实现
__getitem__ 根据全局 step 索引 idx 区分三种数据来源:
| 类型 | 定位方式 | 数据加载 |
|---|---|---|
| demo | idx % num_steps → _step_to_episode_map | self.data(已在内存) |
| success rollout | idx - adjusted_demo_count → success map | 懒加载 HDF5 |
| failure rollout | idx - demo - success → failure map | 懒加载 HDF5 |
另外 return_value_function_returns=True 且采样方式非 demo 时,50% 概率标为 world_model_sample,50% 标为 value_function_sample。
接下来是最重要的拼 image_list。
LIBERO 默认配置下的 latent 序列布局(num_duplicates_per_image=4,即每类占 4 个 RGB 帧 = 1 个 latent 帧):
| current_sequence_idx | 内容 | 实际图像 | latent_idx 变量 |
|---|---|---|---|
| 0 | blank (1帧) | 全零 | (VAE 时序压缩占位) |
| 1 | current proprio | 全零×4 | current_proprio_latent_idx |
| 2 | current wrist | 真实腕部×4 | current_wrist_image_latent_idx |
| 3 | current primary | 真实第三人称×4 | current_image_latent_idx |
| 4 | action chunk | 全零×4 | action_latent_idx |
| 5 | future proprio | 全零×4 | future_proprio_latent_idx |
| 6 | future wrist | 未来腕部×4 | future_wrist_image_latent_idx |
| 7 | future primary | 未来第三人称×4 | future_image_latent_idx |
| 8 | value | 全零×4 | value_latent_idx |
这里还对应了 config 里的 state_t=9。
之后填充 Action Chunk,一个action_chunk 表示未来 chunk_size 步的动作,next_action_chunk 表示从 t+chunk_size 起的下一个动作,末尾不够则一直重复最后一个动作。
最后返回填充完整的字典。
上面有很多全零的部分,这里会在下一部分“核心模型”处填充对应内容。
核心模型
先讲 text2world_model.py,因为 video2world_model.py 是继承自这个文件的。
后文张量维度统一记为 B × C × T × H × W:B 为 batch size,C 为 latent 通道数,T 为时间帧数,H 与 W 分别为空间高宽。
首先有两个函数,分别是 replace_latent_with_action_chunk() 和 replace_latent_with_proprio。这两个函数的作用是填充上面还是全零的内容。
另外是 CosmosPolicyModelConfig 这个类,这个类的定义和 modules/hybrid_edm_sde.py 中的 HybridEDMSDE 是息息相关的,大概是父类采样是纯 log-normal 的方式,这里采用了 70% 纯 log-normal,30% uniform的方式,注意这里的 70% 和 30% 是期望。另外就是新增的五个字段,作用如下:
| 字段 | 默认 | 作用 |
|---|---|---|
mask_loss_for_action_future_state_prediction | False | 按样本类型分流 loss:demo→action,rollout→future state 或 value |
mask_value_prediction_loss_for_policy_prediction | False | demo 同时优化 action+future state,不算 value loss |
mask_current_state_action_for_value_prediction | False | value 样本只优化 value 帧(屏蔽 current state/action 帧的 loss) |
mask_future_state_for_qvalue_prediction | False | Q(s,a) 预测时屏蔽 future state 帧 |
action_loss_multiplier | 1 | action 帧 loss 权重倍数(整数) |
CosmosPolicyDiffusionModel 类
这个是最重要的一段代码,除了 __init__ 之外一共三个函数,但有些比较重要的函数在父类中实现,这里也讲一下作用。
训练流程在 training_step 中,首先更新一下训练统计的信息,之后选择性地编码文本。
之后先获取干净的潜变量和条件,然后采样出噪声水平和噪声,再进行一次广播操作将数据进行切分方便并行计算。
然后进行损失聚合,分为两种 mean 是对所有元素求平均,sum 是沿着通道维度求和并求平均,之后乘以缩放因子。
上面有个关键函数 compute_loss_with_epsilon_and_sigma,下面讲这个。
这里传入了大量的数据,这些数据看名字就能看出作用,不难理解。
输入的原始 x_0 还是纯净的,没有注入 action/proprio/value 这些,首先进行备份然后注入,方法就是上面两个函数的展平重复填满,之后加噪得到 xt,之后去噪得到预测的 model_pred,然后对所有的元素都计算 loss,之后根据上面预设的 mask 方式控制每个 batch 样本的每个 latent 帧是否参与反传。
generate_samples_from_batch 为推理的主入口,方式大概为将 denoise 函数包装到 x0_fn 中,然后传递给 sampler 调用,之后获取去噪的 latent 结果。
CosmosPolicyVideo2WorldModel 类
这里重写了 get_data_and_condition,get_x0_fn_from_batch,denoise 和 draw_training_sigma_and_epsilon 三个函数,另外 Text2WorldModelConfig 类也增加了 Video2WorldConfig 类的一些字段。
LIBERO 实验覆盖:min_num_conditional_frames=4,sigma_conditional=0.0,state_t=9。
get_data_and_condition
这里将前 4 帧设置成了条件帧,同时按照 rollout 任务改 mask, world model模式 action 也设置成条件帧,value模式除了value全部是条件帧数,Q-value 模式下 mask 掉 future state。
此外还要往 gt_frames 中也注入非图像模式。
denoise 中设置了条件帧不去噪,只去噪后面需要去噪的几帧。