cosmos policy code walk

cosmos policy 代码阅读

cosmos policy code walk

9 May 2026

4 min read

简介

对于不知道什么是 World Action Model(WAM) 的读者,建议先阅读一下 这篇博客,这篇博客非常清晰地介绍了什么是 WAM,和 VLA 的区别等。

这篇文章分析一下 cosmos policy 的代码实现,参考 官方代码仓库,当前 commit 版本是 18a2accadf4e7a3531e56754102af5a24d2316da。

部署

仓库提供了 Dockerfile,可以直接部署,我在 HKUST 的 SuperPod 上使用,需要改写为 Apptainer,下面放出我改写后的文件。

构建好镜像后正常运行,可以按照指导说的那样进入交互式模式简单体验一下,这里不再赘述。

代码结构

TXT
cosmos-policy/
├── README.md / SETUP.md / LIBERO.md / ROBOCASA.md / ALOHA.md  # 文档入口
├── docker/                                                     # 可忽略
├── bin/                                                        # 可忽略
└── cosmos_policy/                    ★ 主战场
    ├── models/                       ★★★ 核心模型
    ├── modules/                      ★★  SDE + Sampler
    ├── datasets/                     ★★  数据管线
    ├── experiments/robot/            ★★  推理 & 评测
    ├── config/                       ★   实验配置
    ├── scripts/train.py              ★   训练入口
    ├── trainer.py                    ★   训练循环扩展
    ├── conditioner.py                ★   条件对象
    ├── constants.py                  ★   平台常量
    ├── tokenizers/wan2pt1.py         ☆   policy tokenizer 封装
    ├── utils/                        ☆   工具
    └── _src/                         ✗  三个轮子,整体跳过
        ├── imaginaire/   (~200 files)  训练框架
        ├── predict2/     (~200 files)  视频生成模型
        └── reason1/      (~50 files)   VLM 嵌入(Policy 主要用 T5)

所有的代码都放在 cosmos-policy 目录下,我们后面讲的时候默认 Workspace 在这里。

_src 目录下存放了这里用到的三个轮子,分别是 nvidia 自己造的 imaginairepredict2reason1,这三个库的用法这里不讲。

models 目录下两个文件非常重要,定义了模型的结构,稍后我们重点讲。

modules 下两个文件分别定义了噪声采样器 HybridEDMSDE 和去噪器 CosmosPolicySampler

experiments/robot 下存放实验相关代码,cosmos_utils.py 文件中实现了 get_acion() 函数,非常重要。

数据部分

重点看一下 datasets/libero_datasets.py 这个文件。

加载数据

从 171行到234行这一段是在加载演示数据,将每个 episode 存入内存,格式如下:

Python
self.data[42] = {
    "images":        np.ndarray,  # (150/T, 128/H, 128/W, 3) uint8
    "wrist_images":  np.ndarray,  # (150/T, 128/H, 128/W, 3) uint8
    "proprio":       np.ndarray,  # (150/T, 9) float32,z-score 归一化
    "actions":       np.ndarray,  # (150/T, 7) float32,z-score 归一化
    "command":       "put both the alphabet soup and the tomato sauce in the basket",
    "num_steps":     150,
    "suite":         "libero_10_no_noops_rerendered",
    "returns":       np.ndarray,  # (150/T,) float32,demo 末尾 reward=1.0 折扣得到
}

指令从文件名中解析,之后计算价值,termial reward = 1.0,用传入的常数 gamma 做 Monte Carlo 折扣,之后缩放到 -1,1 之间。

之后跳到 401-414 行,这里构建了全局 step 索引,按 task suite 进行分组,进行均衡采样。

__getitem__ 实现

__getitem__ 根据全局 step 索引 idx 区分三种数据来源:

类型定位方式数据加载
demoidx % num_steps_step_to_episode_mapself.data(已在内存)
success rolloutidx - adjusted_demo_count → success map懒加载 HDF5
failure rolloutidx - demo - success → failure map懒加载 HDF5

另外 return_value_function_returns=True 且采样方式非 demo 时,50% 概率标为 world_model_sample,50% 标为 value_function_sample。

接下来是最重要的拼 image_list。

LIBERO 默认配置下的 latent 序列布局(num_duplicates_per_image=4,即每类占 4 个 RGB 帧 = 1 个 latent 帧):

current_sequence_idx内容实际图像latent_idx 变量
0blank (1帧)全零(VAE 时序压缩占位)
1current proprio全零×4current_proprio_latent_idx
2current wrist真实腕部×4current_wrist_image_latent_idx
3current primary真实第三人称×4current_image_latent_idx
4action chunk全零×4action_latent_idx
5future proprio全零×4future_proprio_latent_idx
6future wrist未来腕部×4future_wrist_image_latent_idx
7future primary未来第三人称×4future_image_latent_idx
8value全零×4value_latent_idx

这里还对应了 config 里的 state_t=9。

之后填充 Action Chunk,一个action_chunk 表示未来 chunk_size 步的动作,next_action_chunk 表示从 t+chunk_size 起的下一个动作,末尾不够则一直重复最后一个动作。

最后返回填充完整的字典。

上面有很多全零的部分,这里会在下一部分“核心模型”处填充对应内容。

核心模型

先讲 text2world_model.py,因为 video2world_model.py 是继承自这个文件的。

后文张量维度统一记为 B × C × T × H × WB 为 batch size,C 为 latent 通道数,T 为时间帧数,HW 分别为空间高宽。

首先有两个函数,分别是 replace_latent_with_action_chunk()replace_latent_with_proprio。这两个函数的作用是填充上面还是全零的内容。

另外是 CosmosPolicyModelConfig 这个类,这个类的定义和 modules/hybrid_edm_sde.py 中的 HybridEDMSDE 是息息相关的,大概是父类采样是纯 log-normal 的方式,这里采用了 70% 纯 log-normal,30% uniform的方式,注意这里的 70% 和 30% 是期望。另外就是新增的五个字段,作用如下:

字段默认作用
mask_loss_for_action_future_state_predictionFalse按样本类型分流 loss:demo→action,rollout→future state 或 value
mask_value_prediction_loss_for_policy_predictionFalsedemo 同时优化 action+future state,不算 value loss
mask_current_state_action_for_value_predictionFalsevalue 样本只优化 value 帧(屏蔽 current state/action 帧的 loss)
mask_future_state_for_qvalue_predictionFalseQ(s,a) 预测时屏蔽 future state 帧
action_loss_multiplier1action 帧 loss 权重倍数(整数)

CosmosPolicyDiffusionModel 类

这个是最重要的一段代码,除了 __init__ 之外一共三个函数,但有些比较重要的函数在父类中实现,这里也讲一下作用。

训练流程在 training_step 中,首先更新一下训练统计的信息,之后选择性地编码文本。

之后先获取干净的潜变量和条件,然后采样出噪声水平和噪声,再进行一次广播操作将数据进行切分方便并行计算。

然后进行损失聚合,分为两种 mean 是对所有元素求平均,sum 是沿着通道维度求和并求平均,之后乘以缩放因子。

上面有个关键函数 compute_loss_with_epsilon_and_sigma,下面讲这个。

这里传入了大量的数据,这些数据看名字就能看出作用,不难理解。

输入的原始 x_0 还是纯净的,没有注入 action/proprio/value 这些,首先进行备份然后注入,方法就是上面两个函数的展平重复填满,之后加噪得到 xt,之后去噪得到预测的 model_pred,然后对所有的元素都计算 loss,之后根据上面预设的 mask 方式控制每个 batch 样本的每个 latent 帧是否参与反传。

generate_samples_from_batch 为推理的主入口,方式大概为将 denoise 函数包装到 x0_fn 中,然后传递给 sampler 调用,之后获取去噪的 latent 结果。

CosmosPolicyVideo2WorldModel 类

这里重写了 get_data_and_condition,get_x0_fn_from_batch,denoise 和 draw_training_sigma_and_epsilon 三个函数,另外 Text2WorldModelConfig 类也增加了 Video2WorldConfig 类的一些字段。

LIBERO 实验覆盖:min_num_conditional_frames=4,sigma_conditional=0.0,state_t=9。

get_data_and_condition

这里将前 4 帧设置成了条件帧,同时按照 rollout 任务改 mask, world model模式 action 也设置成条件帧,value模式除了value全部是条件帧数,Q-value 模式下 mask 掉 future state。

此外还要往 gt_frames 中也注入非图像模式。

denoise 中设置了条件帧不去噪,只去噪后面需要去噪的几帧。

配置与训练入口

推理与评测