TinyWorlds 源码阅读:1000 行代码理解 Genie 架构

为什么选 TinyWorlds

想理解世界模型的架构,有几个选择:

  1. 读 DeepMind 的论文——只有原理,没有代码
  2. 读 GenieRedux——代码完整但复杂,几千行
  3. 读 TinyWorlds——精简实现,1000 多行代码

TinyWorlds 是 GitHub 上最受欢迎的 Genie 复现项目,1.1k stars。它的价值不在于效果好,而在于代码简单,适合学习。

这篇文章带你过一遍核心代码。

项目结构

tinyworlds/
├── models/
│   ├── tokenizer.py      # 视频分词器
│   ├── lam.py            # 隐动作模型
│   └── dynamics.py       # 动力学模型
├── train.py              # 训练脚本
├── generate.py           # 推理脚本
└── data/                 # 数据处理

核心就三个模型文件,加起来不到 1000 行。

分词器(Tokenizer)

分词器的任务:把连续的画面变成离散的 token 序列。

打开 models/tokenizer.py,核心结构是 VQ-VAE:

class Tokenizer(nn.Module):
    def __init__(self, ...):
        self.encoder = Encoder(...)  # 把画面压缩成潜在向量
        self.decoder = Decoder(...)  # 从潜在向量重建画面
        self.codebook = nn.Embedding(num_codes, latent_dim)  # 码本

关键的量化过程:

def quantize(self, z):
    # z: [B, C, H, W] 编码器输出
    # 找到码本中最近的向量
    distances = torch.cdist(z.flatten(2).transpose(1,2), self.codebook.weight)
    indices = distances.argmin(dim=-1)  # 每个位置对应的码本索引
    z_q = self.codebook(indices)  # 查表得到量化后的向量
    return z_q, indices

这段代码做的事:把连续的特征向量"硬塞"到离散的码本里。每个码本条目就是一个 token。

为什么要离散化?因为后面的动力学模型是 Transformer,Transformer 处理离散 token 比连续向量更擅长。

隐动作模型(LAM)

当训练数据里没有动作标签时,LAM 用来推断"发生了什么动作"。

打开 models/lam.py

class LatentActionModel(nn.Module):
    def forward(self, frame_t, frame_t_plus_1):
        # 把前后两帧拼接
        x = torch.cat([frame_t, frame_t_plus_1], dim=1)
        # 通过编码器得到隐动作
        latent_action = self.encoder(x)
        return latent_action

思路很直接:看前后两帧的变化,推断发生了什么动作。

这个"隐动作"不是人类定义的动作(上下左右),而是模型学出来的一个向量。只要这个向量能帮助预测下一帧,模型就学对了。

动力学模型(Dynamics)

这是核心。给定当前帧的 token 和动作,预测下一帧的 token。

打开 models/dynamics.py

class DynamicsModel(nn.Module):
    def __init__(self, ...):
        self.transformer = TransformerEncoder(...)
        self.token_embed = nn.Embedding(num_codes, embed_dim)
        self.action_embed = nn.Linear(action_dim, embed_dim)
        self.output_head = nn.Linear(embed_dim, num_codes)

前向传播:

def forward(self, tokens, action):
    # tokens: [B, H*W] 当前帧的token序列
    # action: [B, action_dim] 动作向量
    
    x = self.token_embed(tokens)  # [B, H*W, D]
    a = self.action_embed(action) # [B, D]
    
    # 把动作加到序列开头
    x = torch.cat([a.unsqueeze(1), x], dim=1)  # [B, 1+H*W, D]
    
    # Transformer 编码
    x = self.transformer(x)
    
    # 预测每个位置的下一个token
    logits = self.output_head(x[:, 1:])  # [B, H*W, num_codes]
    return logits

关键点:

  1. Token 通过 embedding 层变成向量
  2. 动作也变成向量,拼在序列开头
  3. Transformer 处理整个序列
  4. 输出每个位置的预测分布

MaskGIT 训练方式

TinyWorlds 用 MaskGIT 的方式训练动力学模型。

训练时:

def train_step(self, tokens, action, target_tokens):
    # 随机遮住一部分输入token
    mask_ratio = random.uniform(0.5, 1.0)
    mask = torch.rand_like(tokens.float()) < mask_ratio
    masked_tokens = tokens.clone()
    masked_tokens[mask] = MASK_TOKEN
    
    # 预测被遮住的部分
    logits = self.forward(masked_tokens, action)
    loss = F.cross_entropy(logits[mask], target_tokens[mask])
    return loss

推理时:

def generate(self, tokens, action, steps=10):
    current = torch.full_like(tokens, MASK_TOKEN)
    
    for step in range(steps):
        logits = self.forward(current, action)
        probs = F.softmax(logits, dim=-1)
        
        # 每步揭示一部分token
        confidence = probs.max(dim=-1).values
        num_to_reveal = len(current) // steps
        reveal_indices = confidence.topk(num_to_reveal).indices
        
        current[reveal_indices] = probs[reveal_indices].argmax(dim=-1)
    
    return current

MaskGIT 的好处是可以并行预测多个 token,比自回归方式快很多。

训练流程

train.py 里的训练循环:

for batch in dataloader:
    frames, actions = batch  # frames: [B, T, H, W], actions: [B, T-1, D]
    
    # 1. 把帧转成token
    tokens = tokenizer.encode(frames)  # [B, T, h*w]
    
    # 2. 训练动力学模型
    for t in range(T-1):
        current_tokens = tokens[:, t]
        next_tokens = tokens[:, t+1]
        action = actions[:, t]
        
        loss = dynamics.train_step(current_tokens, action, next_tokens)
        loss.backward()
        optimizer.step()

简化后就是:遍历时间步,用当前帧和动作预测下一帧。

推理流程

generate.py 里的生成逻辑:

def generate_video(initial_frame, actions):
    # 编码初始帧
    current_tokens = tokenizer.encode(initial_frame)
    frames = [initial_frame]
    
    for action in actions:
        # 预测下一帧的token
        next_tokens = dynamics.generate(current_tokens, action)
        # 解码成画面
        next_frame = tokenizer.decode(next_tokens)
        
        frames.append(next_frame)
        current_tokens = next_tokens
    
    return frames

一帧一帧往后推。每一帧都基于前一帧和当前动作生成。

和论文的对应关系

TinyWorlds 实现的是 Genie 论文里的核心架构:

论文组件 TinyWorlds 文件 作用
Video Tokenizer tokenizer.py 离散化画面
Latent Action Model lam.py 推断隐动作
Dynamics Model dynamics.py 预测下一帧

论文里还有一些工程优化(比如更高效的 attention、更大的码本等)TinyWorlds 没有实现。这也是为什么效果差距大。

可以改的地方

如果你想基于 TinyWorlds 做实验,几个容易改的点:

增大码本

默认的码本可能只有几百条。增大到几千条,重建质量会提升。

# tokenizer.py
self.codebook = nn.Embedding(4096, latent_dim)  # 从512改成4096

增大 Transformer

默认的 Transformer 可能很小。加深加宽会提升预测能力。

# dynamics.py
self.transformer = TransformerEncoder(
    num_layers=12,  # 从6改成12
    dim=512,        # 从256改成512
    heads=8
)

更长的上下文

默认可能只看一帧历史。改成看多帧:

def forward(self, token_history, action):
    # token_history: [B, T, H*W] 过去T帧的token
    # 拼接成更长的序列

代码质量评价

TinyWorlds 的优点:

  • 结构清晰,每个组件职责明确
  • 注释足够理解意图
  • 没有过度工程化

缺点:

  • 效果和论文差距大
  • 有些实现细节可能不准确
  • 缺少完整的评估代码

作为学习材料是够用的。想做研究的话,可能需要换到 GenieRedux 这样更完整的框架。

读完代码的收获

读完这 1000 行代码,你应该能理解:

  1. 世界模型不是一个模型,而是三个模型的组合
  2. 离散化是关键步骤,让 Transformer 能处理视频
  3. MaskGIT 比自回归更适合图像生成
  4. 隐动作让无标签学习成为可能

这些理解比能不能跑出好效果更重要。

有了这个基础,再去看 Genie 3 的技术博客,或者读 GenieRedux 的完整代码,会容易很多。

posted @ 2026-02-02 10:38  147API  阅读(5)  评论(0)    收藏  举报