Lucidrains-系列项目源码解析-二十四-
Lucidrains 系列项目源码解析(二十四)
.\lucidrains\magvit2-pytorch\magvit2_pytorch\optimizer.py
# 从 torch.optim 模块中导入 AdamW 和 Adam 优化器
from torch.optim import AdamW, Adam
# Split parameters into a weight-decayed group and a non-decayed group.
def separate_weight_decayable_params(params):
    """Partition `params` by tensor rank for weight-decay grouping.

    Tensors with two or more dimensions (e.g. linear / conv weights) go into the
    first list; 0-d and 1-d tensors (typically biases and normalization gains)
    go into the second list, so that weight decay can be skipped for them.

    Returns:
        (decayed, not_decayed) lists, each preserving the input order.
    """
    decayed = []
    not_decayed = []
    for tensor in params:
        if tensor.ndim >= 2:
            decayed.append(tensor)
        else:
            not_decayed.append(tensor)
    return decayed, not_decayed
# Build an Adam / AdamW optimizer, optionally exempting low-rank params from weight decay.
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    **kwargs
):
    """Create an optimizer for `params`.

    Args:
        params: iterable of parameters (tensors).
        lr, betas, eps: forwarded to the optimizer.
        wd: weight decay. When exactly 0, plain `Adam` is returned; otherwise `AdamW`.
        filter_by_requires_grad: drop parameters with `requires_grad == False` first.
        group_wd_params: when True (and wd != 0), parameters with fewer than two
            dimensions are put in a separate group with `weight_decay = 0`.
        **kwargs: accepted for call-site compatibility but currently ignored.

    Returns:
        A `torch.optim.Adam` or `torch.optim.AdamW` instance.
    """
    if filter_by_requires_grad:
        params = list(filter(lambda p: p.requires_grad, params))

    base_kwargs = dict(lr = lr, betas = betas, eps = eps)

    # no decay requested at all: plain Adam is sufficient
    if wd == 0:
        return Adam(params, **base_kwargs)

    decay_kwargs = dict(weight_decay = wd, **base_kwargs)

    if group_wd_params:
        wd_params, no_wd_params = separate_weight_decayable_params(params)
        # the second group overrides the optimizer-level weight decay with 0
        params = [
            dict(params = wd_params),
            dict(params = no_wd_params, weight_decay = 0),
        ]

    return AdamW(params, **decay_kwargs)
.\lucidrains\magvit2-pytorch\magvit2_pytorch\trainer.py
# 导入必要的库
from pathlib import Path
from functools import partial
from contextlib import contextmanager, nullcontext
import torch
from torch import nn
from torch.nn import Module
from torch.utils.data import Dataset, random_split
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
import pytorch_warmup as warmup
from beartype import beartype
from beartype.typing import Optional, Literal, Union, Type
from magvit2_pytorch.optimizer import get_optimizer
from magvit2_pytorch.magvit2_pytorch import VideoTokenizer
from magvit2_pytorch.data import (
VideoDataset,
ImageDataset,
DataLoader,
video_tensor_to_gif
)
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from einops import rearrange
from ema_pytorch import EMA
from pytorch_custom_utils import auto_unwrap_model
# Module-level constants

# accepted values for the trainer's `dataset_type` argument
VideosOrImagesLiteral = Union[Literal['videos'], Literal['images']]

# LambdaLR whose multiplier is always 1, i.e. the learning rate is left unchanged
ConstantLRScheduler = partial(LambdaLR, lr_lambda = lambda _: 1.)

# DDP kwargs with `find_unused_parameters` enabled, so DDP tolerates parameters
# that receive no gradient in a given step
DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(find_unused_parameters = True)
# Helper functions

def exists(v):
    """Return True when `v` is anything other than None."""
    return False if v is None else True
def cycle(dl):
    """Yield items from `dl` forever, restarting iteration whenever it is exhausted.

    Unlike `itertools.cycle`, nothing is cached: each pass iterates `dl` anew,
    so e.g. a shuffling DataLoader produces a fresh order on every pass.
    """
    while True:
        yield from dl
# Training harness for VideoTokenizer.
# NOTE(review): this excerpt is truncated. The `__init__` signature below is missing its
# closing `):` and the whole constructor body. The attributes the methods rely on
# (accelerator, model, ema_model, optimizers, schedulers, dataloaders, folders, step, ...)
# are therefore created in code that is not visible here.
@auto_unwrap_model()
class VideoTokenizerTrainer:
    @beartype
    def __init__(
        self,
        model: VideoTokenizer,
        *,
        batch_size: int,
        num_train_steps: int,                       # presumably stored as self.num_train_steps (train() loop bound)
        learning_rate: float = 1e-5,
        grad_accum_every: int = 1,                  # presumably self.grad_accum_every; valid_step() also runs this many batches
        apply_gradient_penalty_every: int = 4,
        max_grad_norm: Optional[float] = None,
        dataset: Optional[Dataset] = None,
        dataset_folder: Optional[str] = None,
        dataset_type: VideosOrImagesLiteral = 'videos',   # 'videos' or 'images'
        checkpoints_folder = './checkpoints',       # train() joins paths onto this with `/`, so presumably converted to Path
        results_folder = './results',               # valid_step() writes sample GIFs here (same Path assumption)
        random_split_seed = 42,
        valid_frac = 0.05,
        validate_every_step = 100,                  # train(): validate on the main process every N steps
        checkpoint_every_step = 100,                # train(): save a checkpoint every N steps
        num_frames = 17,
        use_wandb_tracking = False,                 # trackers() asserts this is True
        discr_start_after_step = 0.,
        warmup_steps = 1000,
        scheduler: Optional[Type[LRScheduler]] = None,
        # NOTE(review): the dict() defaults below are evaluated once and shared across
        # instances; this is only safe if the constructor never mutates them — confirm.
        scheduler_kwargs: dict = dict(),
        accelerate_kwargs: dict = dict(),
        ema_kwargs: dict = dict(),
        optimizer_kwargs: dict = dict(),
        dataset_kwargs: dict = dict()
@contextmanager
@beartype
def trackers(
    self,
    project_name: str,
    run_name: Optional[str] = None,
    hps: Optional[dict] = None
):
    """Context manager that enables experiment tracking for the duration of a `with` block.

    Args:
        project_name: project name passed to `accelerator.init_trackers`.
        run_name: optional name assigned to the first tracker's run.
        hps: optional hyperparameter dict logged as the tracker config.

    `accelerator.end_training()` is called on exit even if the body raises,
    so the trackers are always finalized.
    """
    assert self.use_wandb_tracking

    self.accelerator.init_trackers(project_name, config = hps)

    try:
        if exists(run_name):
            # assumes the first registered tracker exposes a `.run` object (as W&B does) — TODO confirm for others
            self.accelerator.trackers[0].run.name = run_name

        yield
    finally:
        # previously skipped whenever the `with` body raised, leaving trackers open
        self.accelerator.end_training()
def log(self, **data_kwargs):
    """Send the given keyword metrics to the accelerator's trackers at the current step."""
    current_step = self.step
    self.accelerator.log(data_kwargs, step = current_step)
@property
def device(self):
    """Device reported by the wrapped tokenizer model (`self.model.device`)."""
    model = self.model
    return model.device
@property
def is_main(self):
    """True only on the accelerator's global main process."""
    accelerator = self.accelerator
    return accelerator.is_main_process
@property
def is_local_main(self):
    """True only on the main process of the local machine."""
    accelerator = self.accelerator
    return accelerator.is_local_main_process
def wait(self):
    """Barrier: block until every process reaches this point."""
    barrier = self.accelerator.wait_for_everyone
    return barrier()
def print(self, msg):
    """Print `msg` through `accelerator.print` (emitted once per machine, not per process)."""
    accelerator = self.accelerator
    return accelerator.print(msg)
@property
def ema_tokenizer(self):
    """The EMA copy of the tokenizer held inside the EMA wrapper."""
    ema_wrapper = self.ema_model
    return ema_wrapper.ema_model
def tokenize(self, *args, **kwargs):
    """Tokenize with the EMA tokenizer; all arguments are forwarded unchanged."""
    tokenizer = self.ema_tokenizer
    return tokenizer.tokenize(*args, **kwargs)
# Serialize the full training state to a single checkpoint file.
def save(self, path, overwrite = True):
    """Write model, EMA, optimizer, warmup/scheduler states and the step counter to `path`.

    Args:
        path: destination file (str or Path).
        overwrite: when False, asserts that `path` does not already exist.
    """
    path = Path(path)
    assert overwrite or not path.exists()

    pkg = {
        'model': self.model.state_dict(),
        'ema_model': self.ema_model.state_dict(),
        'optimizer': self.optimizer.state_dict(),
        'discr_optimizer': self.discr_optimizer.state_dict(),
        'warmup': self.warmup.state_dict(),
        'scheduler': self.scheduler.state_dict(),
        'discr_warmup': self.discr_warmup.state_dict(),
        'discr_scheduler': self.discr_scheduler.state_dict(),
        'step': self.step,
    }

    # one entry per multiscale discriminator optimizer, keyed by its index
    pkg.update({
        f'multiscale_discr_optimizer_{ind}': opt.state_dict()
        for ind, opt in enumerate(self.multiscale_discr_optimizers)
    })

    torch.save(pkg, str(path))
# Restore the full training state from a checkpoint file.
def load(self, path):
    """Load a checkpoint written by `save` into this trainer.

    Args:
        path: checkpoint file (str or Path); must exist.

    Tensors are deserialized onto the CPU first. `load_state_dict` then copies them
    onto the devices of the existing model parameters, and optimizers cast their state
    to their parameters' devices. Without `map_location`, every process would first
    materialize the tensors on the device they were saved from (e.g. cuda:0), and
    GPU-saved checkpoints could not be loaded on CPU-only machines.
    """
    path = Path(path)
    assert path.exists()

    pkg = torch.load(str(path), map_location = 'cpu')

    self.model.load_state_dict(pkg['model'])
    self.ema_model.load_state_dict(pkg['ema_model'])

    self.optimizer.load_state_dict(pkg['optimizer'])
    self.discr_optimizer.load_state_dict(pkg['discr_optimizer'])

    self.warmup.load_state_dict(pkg['warmup'])
    self.scheduler.load_state_dict(pkg['scheduler'])

    self.discr_warmup.load_state_dict(pkg['discr_warmup'])
    self.discr_scheduler.load_state_dict(pkg['discr_scheduler'])

    for ind, opt in enumerate(self.multiscale_discr_optimizers):
        opt.load_state_dict(pkg[f'multiscale_discr_optimizer_{ind}'])

    self.step = pkg['step']
# Validation: average reconstruction losses over a few batches and optionally save a GIF.
@torch.no_grad()
def valid_step(
    self,
    dl_iter,
    save_recons = True,
    num_save_recons = 1
):
    """Evaluate reconstruction loss of the model and of its EMA copy.

    Runs `self.grad_accum_every` batches from `dl_iter`, which is expected to yield
    1-tuples `(video,)`. Losses are logged and printed. When `save_recons` is True,
    the first `num_save_recons` inputs are written next to their EMA reconstructions
    as a GIF in `self.results_folder`.
    """
    self.ema_model.eval()

    num_batches = self.grad_accum_every
    recon_loss = 0.
    ema_recon_loss = 0.
    real_chunks = []
    recon_chunks = []

    for _ in range(num_batches):
        valid_video, = next(dl_iter)
        valid_video = valid_video.to(self.device)

        with self.accelerator.autocast():
            loss, _ = self.model(valid_video, return_recon_loss_only = True)
            ema_loss, ema_recon_video = self.ema_model(valid_video, return_recon_loss_only = True)

        recon_loss += loss / num_batches
        ema_recon_loss += ema_loss / num_batches

        # images get a singleton frame axis so they can be stacked like videos
        if valid_video.ndim == 4:
            valid_video = rearrange(valid_video, 'b c h w -> b c 1 h w')

        real_chunks.append(valid_video.cpu())
        recon_chunks.append(ema_recon_video.cpu())

    self.log(
        valid_recon_loss = recon_loss.item(),
        valid_ema_recon_loss = ema_recon_loss.item()
    )

    self.print(f'validation recon loss {recon_loss:.3f}')
    self.print(f'validation EMA recon loss {ema_recon_loss:.3f}')

    if not save_recons:
        return

    real = torch.cat(real_chunks)[:num_save_recons]
    # clamp reconstructions into the displayable [0, 1] range before slicing
    recon = torch.cat(recon_chunks).clamp_(min = 0., max = 1.)[:num_save_recons]

    # batch items stacked vertically, (real, recon) side by side horizontally
    real_and_recon = rearrange([real, recon], 'n b c f h w -> c f (b h) (n w)')

    validate_step = self.step // self.validate_every_step
    sample_path = str(self.results_folder / f'sampled.{validate_step}.gif')

    video_tensor_to_gif(real_and_recon, str(sample_path))

    self.print(f'sample saved to {str(sample_path)}')
# Main training loop.
def train(self):
    """Train until `self.num_train_steps`, validating and checkpointing periodically.

    Validation and checkpointing run only on the main process. `self.wait()`
    barriers around them keep the other processes in lockstep. These barriers were
    described by the "wait" comments in the loop but had no corresponding calls.
    """
    step = self.step

    dl_iter = cycle(self.dataloader)
    valid_dl_iter = cycle(self.valid_dataloader)

    while step < self.num_train_steps:
        self.print(f'step {step}')

        self.train_step(dl_iter)

        self.wait()

        if self.is_main and not (step % self.validate_every_step):
            self.valid_step(valid_dl_iter)

        # other processes wait here while the main process validates
        self.wait()

        if self.is_main and not (step % self.checkpoint_every_step):
            checkpoint_num = step // self.checkpoint_every_step
            checkpoint_path = self.checkpoints_folder / f'checkpoint.{checkpoint_num}.pt'
            self.save(str(checkpoint_path))

        # other processes wait here while the main process writes the checkpoint
        self.wait()

        step += 1
.\lucidrains\magvit2-pytorch\magvit2_pytorch\version.py
# 定义变量 __version__,赋值为字符串 '0.4.0'
__version__ = '0.4.0'
.\lucidrains\magvit2-pytorch\magvit2_pytorch\__init__.py
# 从 magvit2_pytorch 包中导入 MagViT2 和 VideoTokenizer 类
from magvit2_pytorch.magvit2_pytorch import (
MagViT2,
VideoTokenizer
)
# 从 magvit2_pytorch 包中导入 VideoTokenizerTrainer 类
from magvit2_pytorch.trainer import (
VideoTokenizerTrainer
)

MagViT2 - Pytorch
Implementation of MagViT2, from the paper "Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation", in PyTorch. This currently holds SOTA for video generation / understanding.
The Lookup Free Quantizer proposed in the paper can be found in a separate repository. It should probably be explored for all other modalities, starting with audio.
If you are interested in replicating the tokenizer proposed in this paper out in the open, please join the effort.
Appreciation
- StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
- Louis Serrano for sharing some early runs, validating that the overall architecture converges with finite scalar quantization.
- You? If you are a talented research engineer / scientist, feel free to contribute to cutting edge open source science!
Install
$ pip install magvit2-pytorch
Usage
# `torch` is needed below for the mock video (torch.randn) and the sanity check (torch.allclose)
import torch

from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

tokenizer = VideoTokenizer(
    image_size = 128,
    init_dim = 64,
    max_dim = 512,
    codebook_size = 1024,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'linear_attend_space',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
        'compress_time',
        ('consecutive_residual', 2),
        'compress_time',
        ('consecutive_residual', 2),
        'attend_time',
    )
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder = '/path/to/a/lot/of/media',    # folder of either videos or images, depending on setting below
    dataset_type = 'videos',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 4,
    grad_accum_every = 8,
    learning_rate = 2e-5,
    num_train_steps = 1_000_000
)

trainer.train()

# after a lot of training ...
# can use the EMA of the tokenizer

ema_tokenizer = trainer.ema_tokenizer

# mock video

video = torch.randn(1, 3, 17, 128, 128)

# tokenizing video to discrete codes

codes = ema_tokenizer.tokenize(video) # (1, 9, 16, 16) <- in this example, time downsampled by 4x and space downsampled by 8x. flatten token ids for (non)-autoregressive training

# sanity check

decoded_video = ema_tokenizer.decode_from_code_indices(codes)

assert torch.allclose(
    decoded_video,
    ema_tokenizer(video, return_recon = True)
)
To track your experiments on Weights & Biases, set `use_wandb_tracking = True` on `VideoTokenizerTrainer`, and then use the `.trackers` context manager:
# `...` is a placeholder for the remaining constructor arguments (tokenizer, dataset, batch size, ...)
# shown in the usage example above; it is not literal Python.
trainer = VideoTokenizerTrainer(
    use_wandb_tracking = True,
    ...
)

# trackers are initialized on entry and finalized (accelerator.end_training) on exit
with trainer.trackers(project_name = 'magvit2', run_name = 'baseline'):
    trainer.train()
Todo
Citations
@misc{yu2023language,
title = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
author = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
year = {2023},
eprint = {2310.05737},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
@article{Zhang2021TokenST,
title = {Token Shift Transformer for Video Classification},
author = {Hao Zhang and Y. Hao and Chong-Wah Ngo},
journal = {Proceedings of the 29th ACM International Conference on Multimedia},
year = {2021}
}
@inproceedings{Arora2023ZoologyMA,
title = {Zoology: Measuring and Improving Recall in Efficient Language Models},
author = {Simran Arora and Sabri Eyuboglu and Aman Timalsina and Isys Johnson and Michael Poli and James Zou and Atri Rudra and Christopher R{\'e}},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:266149332}
}
.\lucidrains\magvit2-pytorch\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# Read __version__ by executing magvit2_pytorch/version.py directly instead of importing the
# package: magvit2_pytorch/__init__.py imports the model and trainer modules, and therefore
# their third-party dependencies. The `with` block closes the file (previously leaked).
with open('magvit2_pytorch/version.py') as version_file:
    exec(version_file.read())

setup(
    name = 'magvit2-pytorch',
    packages = find_packages(),
    version = __version__,
    license='MIT',
    description = 'MagViT2 - Pytorch',
    # NOTE(review): no `long_description` is passed, so this content type currently has no effect
    long_description_content_type = 'text/markdown',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/magvit2-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'transformer',
        'attention mechanisms',
        'generative video model'
    ],
    install_requires=[
        'accelerate>=0.24.0',
        'beartype',
        'einops>=0.7.0',
        'ema-pytorch>=0.2.4',
        'pytorch-warmup',
        'gateloop-transformer>=0.2.2',
        'kornia',
        'opencv-python',
        'pillow',
        'pytorch-custom-utils>=0.0.9',
        'numpy',
        'vector-quantize-pytorch>=1.11.8',
        'taylor-series-linear-attention>=0.1.5',
        'torch',
        'torchvision',
        'x-transformers'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\make-a-video-pytorch\make_a_video_pytorch\attend.py
# 导入必要的库
from functools import wraps
from packaging import version
from collections import namedtuple
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
# Backend switches unpacked into torch.backends.cuda.sdp_kernel(...) by Attend.flash_attn.
AttentionConfig = namedtuple(
    'AttentionConfig',
    'enable_flash enable_math enable_mem_efficient'
)

def exists(val):
    """Return True if `val` is not None."""
    return not (val is None)
# Decorator that lets the wrapped function run only on its first call.
def once(fn):
    """Wrap `fn` so that only its first invocation executes; later calls return None.

    All positional and keyword arguments are forwarded. The previous wrapper
    accepted exactly one positional argument, so e.g. `print_once(a, b)` or
    `print_once(msg, flush = True)` raised TypeError.
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        called = True
        return fn(*args, **kwargs)

    return inner

# `print` that only emits its first message (used by Attend for one-time GPU notices)
print_once = once(print)
# Attention core: softmax(q k^T * d^-0.5 + bias) v, optionally through PyTorch 2.0 SDPA.
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        causal = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # SDPA backend choices: every backend is allowed on CPU
        self.cpu_config = AttentionConfig(True, True, True)
        self.cuda_config = None

        if torch.cuda.is_available() and flash:
            props = torch.cuda.get_device_properties(torch.device('cuda'))
            # compute capability 8.0 is the A100
            if props.major == 8 and props.minor == 0:
                print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
                self.cuda_config = AttentionConfig(True, False, False)
            else:
                print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
                self.cuda_config = AttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        """Attention via F.scaled_dot_product_attention with the device-appropriate backends."""
        q, k, v = (t.contiguous() for t in (q, k, v))

        config = self.cuda_config if q.is_cuda else self.cpu_config

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            return F.scaled_dot_product_attention(
                q, k, v,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = self.causal
            )

    def forward(self, q, k, v, bias = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        if self.flash:
            # the SDPA path takes no additive bias
            assert not exists(bias)
            return self.flash_attn(q, k, v)

        scale = q.shape[-1] ** -0.5
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale

        if exists(bias):
            sim = sim + bias

        if self.causal:
            i, j = sim.shape[-2:]
            # mask keys strictly in the future, aligned to the last query when i != j
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        attn = self.attn_dropout(sim.softmax(dim = -1))

        return einsum('b h i j, b h j d -> b h i d', attn, v)
.\lucidrains\make-a-video-pytorch\make_a_video_pytorch\make_a_video.py
# 导入数学库
import math
# 导入 functools 库
import functools
# 从 operator 库中导入 mul 函数
from operator import mul
# 导入 torch 库
import torch
# 从 torch.nn 中导入 functional 模块
import torch.nn.functional as F
# 从 torch 中导入 nn、einsum 模块
from torch import nn, einsum
# 从 einops 中导入 rearrange、repeat、pack、unpack 函数,以及 Rearrange 类
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# 从 make_a_video_pytorch.attend 模块中导入 Attend 类
# Helper functions

def exists(val):
    """Return True if `val` is not None."""
    return False if val is None else True
def default(val, d):
    """Return `val`, or the fallback `d` when `val` is None."""
    return d if val is None else val
# Product of all elements of an iterable.
def mul_reduce(tup):
    """Return the product of the elements of `tup`.

    The initial value 1 makes the empty product well defined: `mul_reduce(())` now
    returns 1 instead of raising TypeError from `functools.reduce`.
    """
    return functools.reduce(mul, tup, 1)
def divisible_by(numer, denom):
    """Return True if `numer` is an exact multiple of `denom`."""
    remainder = numer % denom
    return remainder == 0

# short alias for building (nested) module lists
mlist = nn.ModuleList
# Sinusoidal embedding of scalar positions (used by SpaceTimeUnet for timestep conditioning).
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x):
        """Embed a 1-d float tensor of positions into `(len(x), 2 * (dim // 2))` features."""
        dtype, device = x.dtype, x.device
        assert dtype == torch.float, 'input to sinusoidal pos emb must be a float type'

        half_dim = self.dim // 2
        # log-spaced frequencies from 1 down to 1 / theta
        log_step = math.log(self.theta) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -log_step)

        angles = rearrange(x, 'i -> i 1') * rearrange(freqs, 'j -> 1 j')
        return torch.cat((angles.sin(), angles.cos()), dim = -1).type(dtype)
# RMS normalization along one axis with a learned per-channel gain.
class RMSNorm(nn.Module):
    """RMS-normalize `x` along axis `dim` and scale by a learned gain of length `chan`.

    Args:
        chan: size of the normalized axis (length of the learned gain `gamma`).
        dim: axis to normalize over; may be negative. Defaults to 1, the channel
            axis of (b, c, ...) tensors.
    """

    def __init__(self, chan, dim = 1):
        super().__init__()
        self.dim = dim
        self.gamma = nn.Parameter(torch.ones(chan))

    def forward(self, x):
        dim = self.dim
        # Number of axes after `dim`. gamma gets that many trailing singleton axes so it
        # broadcasts along `dim`. For negative dim this is `-dim - 1`; the previous
        # `dim + 1` was only correct for dim = -1 and mis-aligned gamma for dim <= -2.
        right_ones = (-dim - 1) if dim < 0 else (x.ndim - 1 - dim)
        gamma = self.gamma.reshape(-1, *((1,) * right_ones))
        # unit L2 norm times sqrt(n) == unit root-mean-square
        return F.normalize(x, dim = dim) * (x.shape[dim] ** 0.5) * gamma
# Feed-forward helpers

def shift_token(t):
    """Shift the second half of the channels (dim 1) forward by one step on the third-from-last axis.

    Applied to (b, c, f, h, w) feature maps by FeedForward, the shifted axis is
    the frame axis: the first frame of the shifted half becomes zeros and its last
    frame is dropped.
    """
    keep, shifted = t.chunk(2, dim = 1)
    # F.pad pairs apply from the last axis backwards: (w), (h), then (1, -1) on the frame axis
    shifted = F.pad(shifted, (0, 0, 0, 0, 1, -1), value = 0.)
    return torch.cat((keep, shifted), dim = 1)
# Gated GELU: split channels (dim 1) into value and gate halves, return value * gelu(gate).
class GEGLU(nn.Module):
    def forward(self, x):
        value, gate = x.chunk(2, dim = 1)
        return value * F.gelu(gate)
# Pointwise (1x1x1 conv) GEGLU feed-forward with optional token shift along time.
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4):
        super().__init__()
        inner_dim = int(dim * mult * 2 / 3)

        # doubled width: GEGLU halves it again (value / gate)
        self.proj_in = nn.Sequential(
            nn.Conv3d(dim, inner_dim * 2, 1, bias = False),
            GEGLU()
        )

        self.proj_out = nn.Sequential(
            RMSNorm(inner_dim),
            nn.Conv3d(inner_dim, dim, 1, bias = False)
        )

    def forward(self, x, enable_time = True):
        """Accept (b, c, h, w) images or (b, c, f, h, w) videos; time mixing only for videos."""
        is_video = x.ndim == 5
        use_time = enable_time and is_video

        # images are treated as single-frame videos for the 3D convs
        if not is_video:
            x = rearrange(x, 'b c h w -> b c 1 h w')

        hidden = self.proj_in(x)

        if use_time:
            hidden = shift_token(hidden)

        out = self.proj_out(hidden)

        return out if is_video else rearrange(out, 'b c 1 h w -> b c h w')
# Continuous relative position bias: an MLP maps relative offsets to per-head biases.
class ContinuousPositionBias(nn.Module):
    """ from https://arxiv.org/abs/2111.09883 """

    def __init__(
        self,
        *,
        dim,
        heads,
        num_dims = 1,
        layers = 2
    ):
        super().__init__()
        self.num_dims = num_dims

        # MLP: num_dims -> dim (SiLU) -> ... -> heads
        self.net = nn.ModuleList([])
        self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU()))

        for _ in range(layers - 1):
            self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))

        # closing parenthesis was missing here, which made the module a syntax error
        self.net.append(nn.Linear(dim, heads))

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, *dimensions):
        """Return a (heads, n, n) bias for a grid of size `dimensions`, n = prod(dimensions)."""
        device = self.device

        shape = torch.tensor(dimensions, device = device)
        rel_pos_shape = 2 * shape - 1

        # row-major strides over the grid of (2 * d - 1)-sized relative offsets
        strides = torch.flip(rel_pos_shape, (0,)).cumprod(dim = -1)
        strides = torch.flip(F.pad(strides, (1, -1), value = 1), (0,))

        # pairwise relative offsets between all grid positions: (n, n, num_dims)
        positions = [torch.arange(d, device = device) for d in dimensions]
        grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'), dim = -1)
        grid = rearrange(grid, '... c -> (...) c')
        rel_dist = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')

        # every distinct relative offset, run once through the MLP
        rel_positions = [torch.arange(-d + 1, d, device = device) for d in dimensions]
        rel_pos_grid = torch.stack(torch.meshgrid(*rel_positions, indexing = 'ij'), dim = -1)
        rel_pos_grid = rearrange(rel_pos_grid, '... c -> (...) c')

        bias = rel_pos_grid.float()

        for layer in self.net:
            bias = layer(bias)

        # shift offsets to be non-negative, then flatten them into indices into `bias`
        rel_dist += (shape - 1)
        rel_dist *= strides
        rel_dist_indices = rel_dist.sum(dim = -1)

        bias = bias[rel_dist_indices]
        return rearrange(bias, 'i j h -> h i j')
# Pre-norm multi-head self-attention over (b, n, dim) token sequences.
# NOTE(review): `Attend` comes from make_a_video_pytorch.attend, but that import is absent from
# this module's import block (only a comment announcing it is present) — confirm it is imported.
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        flash = False,
        causal = False
    ):
        super().__init__()
        self.heads = heads
        # kept for reference; Attend computes its own scale from the head dimension
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.attend = Attend(flash = flash, causal = causal)

        self.norm = RMSNorm(dim, dim = -1)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # zero-init output projection: the block initially outputs zeros, so a residual add is identity
        nn.init.zeros_(self.to_out.weight.data)

    def forward(
        self,
        x,
        rel_pos_bias = None
    ):
        x = self.norm(x)

        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim = -1)

        heads = self.heads
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = heads) for t in (q, k, v))

        out = self.attend(q, k, v, bias = rel_pos_bias)

        merged = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(merged)
# Factorized ("pseudo") 3D convolution: 2D spatial conv per frame, then 1D temporal conv per pixel.
class PseudoConv3d(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        kernel_size = 3,
        *,
        temporal_kernel_size = None,
        **kwargs
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        temporal_kernel_size = default(temporal_kernel_size, kernel_size)

        self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)

        # pointwise (kernel_size == 1) convolutions get no temporal component
        if kernel_size > 1:
            self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size, padding = temporal_kernel_size // 2)
        else:
            self.temporal_conv = None

        if exists(self.temporal_conv):
            # dirac weights + zero bias: for odd temporal kernels the temporal conv starts as identity
            nn.init.dirac_(self.temporal_conv.weight.data)
            nn.init.zeros_(self.temporal_conv.bias.data)

    def forward(
        self,
        x,
        enable_time = True
    ):
        b, c, *_, h, w = x.shape

        is_video = x.ndim == 5
        use_time = enable_time and is_video

        # fold frames into the batch for the 2D spatial conv
        if is_video:
            x = rearrange(x, 'b c f h w -> (b f) c h w')

        x = self.spatial_conv(x)

        if is_video:
            x = rearrange(x, '(b f) c h w -> b c f h w', b = b)

        if not use_time or self.temporal_conv is None:
            return x

        # fold pixels into the batch for the 1D temporal conv
        x = rearrange(x, 'b c f h w -> (b h w) c f')
        x = self.temporal_conv(x)
        return rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)
# Factorized space-time attention: spatial attention within each frame, then temporal
# attention across frames at each pixel, then an optional feed-forward (all residual).
class SpatioTemporalAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        add_feed_forward = True,
        ff_mult = 4,
        pos_bias = True,
        flash = False,
        causal_time_attn = False
    ):
        super().__init__()
        assert not (flash and pos_bias), 'learned positional attention bias is not compatible with flash attention'

        self.spatial_attn = Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash)
        self.spatial_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 2) if pos_bias else None

        self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash, causal = causal_time_attn)
        self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1) if pos_bias else None

        self.has_feed_forward = add_feed_forward

        if add_feed_forward:
            self.ff = FeedForward(dim = dim, mult = ff_mult)

    def forward(
        self,
        x,
        enable_time = True
    ):
        """Accept (b, c, h, w) images or (b, c, f, h, w) videos; temporal attention only for videos."""
        b, c, *_, h, w = x.shape
        is_video = x.ndim == 5
        use_time = enable_time and is_video

        # spatial attention: one sequence of h*w tokens per frame
        if is_video:
            tokens = rearrange(x, 'b c f h w -> (b f) (h w) c')
        else:
            tokens = rearrange(x, 'b c h w -> b (h w) c')

        space_bias = self.spatial_rel_pos_bias(h, w) if exists(self.spatial_rel_pos_bias) else None
        tokens = self.spatial_attn(tokens, rel_pos_bias = space_bias) + tokens

        if is_video:
            x = rearrange(tokens, '(b f) (h w) c -> b c f h w', b = b, h = h, w = w)
        else:
            x = rearrange(tokens, 'b (h w) c -> b c h w', h = h, w = w)

        # temporal attention: one sequence of frames per spatial location
        if use_time:
            tokens = rearrange(x, 'b c f h w -> (b h w) f c')
            time_bias = self.temporal_rel_pos_bias(tokens.shape[1]) if exists(self.temporal_rel_pos_bias) else None
            tokens = self.temporal_attn(tokens, rel_pos_bias = time_bias) + tokens
            x = rearrange(tokens, '(b h w) f c -> b c f h w', w = w, h = h)

        if self.has_feed_forward:
            x = self.ff(x, enable_time = use_time) + x

        return x
# Conv -> GroupNorm -> optional (scale, shift) modulation -> SiLU.
class Block(nn.Module):
    """Basic conv block used inside ResnetBlock.

    Args:
        dim, dim_out: input / output channels.
        kernel_size: spatial kernel size of the PseudoConv3d projection.
        temporal_kernel_size: temporal kernel size (defaults to `kernel_size` inside PseudoConv3d).
        groups: number of GroupNorm groups.
    """

    def __init__(
        self,
        dim,
        dim_out,
        kernel_size = 3,
        temporal_kernel_size = None,
        groups = 8
    ):
        super().__init__()
        # Previously hard-coded as PseudoConv3d(dim, dim_out, 3), which silently ignored
        # `kernel_size` / `temporal_kernel_size`. With the defaults this builds the same conv.
        self.project = PseudoConv3d(dim, dim_out, kernel_size, temporal_kernel_size = temporal_kernel_size)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(
        self,
        x,
        scale_shift = None,
        enable_time = False
    ):
        x = self.project(x, enable_time = enable_time)
        x = self.norm(x)

        # conditioning: x * (1 + scale) + shift, so zero scale/shift leaves x unchanged
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        return self.act(x)
# Residual block: two Blocks plus a (projected) skip, optionally conditioned on a timestep embedding.
class ResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        *,
        timestep_cond_dim = None,
        groups = 8
    ):
        super().__init__()

        # maps the timestep embedding to per-channel (scale, shift) for the first Block
        if exists(timestep_cond_dim):
            self.timestep_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(timestep_cond_dim, dim_out * 2)
            )
        else:
            self.timestep_mlp = None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        # 1x1 projection on the skip path only when the channel count changes
        self.res_conv = PseudoConv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(
        self,
        x,
        timestep_emb = None,
        enable_time = True
    ):
        # a timestep embedding must be given exactly when the block was built to take one
        assert not (exists(timestep_emb) ^ exists(self.timestep_mlp))

        scale_shift = None

        if exists(self.timestep_mlp) and exists(timestep_emb):
            cond = self.timestep_mlp(timestep_emb)
            # broadcast over spatial (and frame) axes
            pattern = 'b c 1 1 1' if x.ndim == 5 else 'b c 1 1'
            cond = rearrange(cond, f'b c -> {pattern}')
            scale_shift = cond.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift, enable_time = enable_time)
        h = self.block2(h, enable_time = enable_time)

        return h + self.res_conv(x)
# Pixel-unshuffle downsampling by 2 in space and/or time, followed by a 1x1 conv back to `dim` channels.
class Downsample(nn.Module):
    def __init__(
        self,
        dim,
        downsample_space = True,
        downsample_time = False,
        nonlin = False
    ):
        super().__init__()
        assert downsample_space or downsample_time

        if downsample_space:
            self.down_space = nn.Sequential(
                Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
                nn.Conv2d(dim * 4, dim, 1, bias = False),
                nn.SiLU() if nonlin else nn.Identity()
            )
        else:
            self.down_space = None

        if downsample_time:
            self.down_time = nn.Sequential(
                Rearrange('b c (f p) h w -> b (c p) f h w', p = 2),
                nn.Conv3d(dim * 2, dim, 1, bias = False),
                nn.SiLU() if nonlin else nn.Identity()
            )
        else:
            self.down_time = None

    def forward(
        self,
        x,
        enable_time = True
    ):
        is_video = x.ndim == 5

        # fold frames into the batch so the 2D spatial path sees (N, c, h, w)
        if is_video:
            x = rearrange(x, 'b c f h w -> b f c h w')
            x, ps = pack([x], '* c h w')

        if exists(self.down_space):
            x = self.down_space(x)

        if is_video:
            x, = unpack(x, ps, '* c h w')
            x = rearrange(x, 'b f c h w -> b c f h w')

        apply_time = is_video and exists(self.down_time) and enable_time
        if not apply_time:
            return x

        return self.down_time(x)
# Pixel-shuffle upsampling by 2 in space and/or time: a 1x1 conv expands channels, then a rearrange.
class Upsample(nn.Module):
    def __init__(
        self,
        dim,
        upsample_space = True,
        upsample_time = False,
        nonlin = False
    ):
        super().__init__()
        assert upsample_space or upsample_time

        if upsample_space:
            self.up_space = nn.Sequential(
                nn.Conv2d(dim, dim * 4, 1),
                nn.SiLU() if nonlin else nn.Identity(),
                Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = 2, p2 = 2)
            )
        else:
            self.up_space = None

        if upsample_time:
            self.up_time = nn.Sequential(
                nn.Conv3d(dim, dim * 2, 1),
                nn.SiLU() if nonlin else nn.Identity(),
                Rearrange('b (c p) f h w -> b c (f p) h w', p = 2)
            )
        else:
            self.up_time = None

        self.init_()

    def init_(self):
        """Re-initialize the channel-expanding conv of each present upsampling path."""
        for seq, factor in ((self.up_space, 4), (self.up_time, 2)):
            if exists(seq):
                self.init_conv_(seq[0], factor)

    def init_conv_(self, conv, factor):
        """Give each group of `factor` consecutive output channels identical weights and zero bias.

        After the pixel-shuffle rearrange, every output channel's `factor` sub-positions
        therefore start out equal, i.e. a nearest-neighbour style upsample of the conv output.
        """
        out_channels, *rest = conv.weight.shape
        base = torch.empty(out_channels // factor, *rest)
        nn.init.kaiming_uniform_(base)
        conv.weight.data.copy_(repeat(base, 'o ... -> (o r) ...', r = factor))
        nn.init.zeros_(conv.bias.data)

    def forward(
        self,
        x,
        enable_time = True
    ):
        is_video = x.ndim == 5

        # fold frames into the batch so the 2D spatial path sees (N, c, h, w)
        if is_video:
            x = rearrange(x, 'b c f h w -> b f c h w')
            x, ps = pack([x], '* c h w')

        if exists(self.up_space):
            x = self.up_space(x)

        if is_video:
            x, = unpack(x, ps, '* c h w')
            x = rearrange(x, 'b f c h w -> b c f h w')

        apply_time = is_video and exists(self.up_time) and enable_time
        if not apply_time:
            return x

        return self.up_time(x)
# space time factorized 3d unet
class SpaceTimeUnet(nn.Module):
def __init__(
    self,
    *,
    dim,                                                # base channel width
    channels = 3,                                       # input / output channels
    dim_mult = (1, 2, 4, 8),                            # per-stage channel multipliers (stage width = dim * mult)
    self_attns = (False, False, False, True),           # per stage: add a SpatioTemporalAttention block
    temporal_compression = (False, True, True, True),   # per stage: also down/upsample the frame axis by 2
    resnet_block_depths = (2, 2, 2, 2),                 # per stage: number of extra ResnetBlocks (>= 1)
    attn_dim_head = 64,                                 # dimension of each attention head
    attn_heads = 8,                                     # number of attention heads
    condition_on_timestep = True,                       # build the diffusion-timestep conditioning MLP
    attn_pos_bias = True,                               # learned continuous relative position bias in attention
    flash_attn = False,                                 # use PyTorch 2.0 scaled_dot_product_attention
    causal_time_attn = False                            # causal mask in temporal attention
):
    # Build the space-time factorized 3D U-Net. NOTE: the order of module construction below
    # determines parameter initialization order and state_dict layout.
    super().__init__()

    assert len(dim_mult) == len(self_attns) == len(temporal_compression) == len(resnet_block_depths)
    num_layers = len(dim_mult)

    # channel widths [dim, dim * dim_mult[0], dim * dim_mult[1], ...] and per-stage (in, out) pairs
    dims = [dim, *map(lambda mult: mult * dim, dim_mult)]
    dim_in_out = zip(dims[:-1], dims[1:])

    # determine the valid multiples of the image size and frames of the video
    # frames must be divisible by 2 ** (number of temporally compressed stages)
    self.frame_multiple = 2 ** sum(tuple(map(int, temporal_compression)))
    # every stage halves height and width (Downsample defaults to downsample_space = True)
    self.image_size_multiple = 2 ** num_layers

    # timestep conditioning for DDPM, not to be confused with the time dimension of the video

    self.to_timestep_cond = None
    timestep_cond_dim = (dim * 4) if condition_on_timestep else None

    if condition_on_timestep:
        self.to_timestep_cond = nn.Sequential(
            SinusoidalPosEmb(dim),                      # scalar timestep -> dim features
            nn.Linear(dim, timestep_cond_dim),
            nn.SiLU()
        )

    # layers

    self.downs = mlist([])
    self.ups = mlist([])

    attn_kwargs = dict(
        dim_head = attn_dim_head,
        heads = attn_heads,
        pos_bias = attn_pos_bias,
        flash = flash_attn,
        causal_time_attn = causal_time_attn
    )

    # bottleneck at the last (deepest) stage width
    mid_dim = dims[-1]

    self.mid_block1 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim = timestep_cond_dim)
    self.mid_attn = SpatioTemporalAttention(dim = mid_dim, **attn_kwargs)
    self.mid_block2 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim = timestep_cond_dim)

    for _, self_attend, (dim_in, dim_out), compress_time, resnet_block_depth in zip(range(num_layers), self_attns, dim_in_out, temporal_compression, resnet_block_depths):
        assert resnet_block_depth >= 1

        # encoder stage: timestep-conditioned entry block, extra blocks, optional attention, downsample
        self.downs.append(mlist([
            ResnetBlock(dim_in, dim_out, timestep_cond_dim = timestep_cond_dim),
            mlist([ResnetBlock(dim_out, dim_out) for _ in range(resnet_block_depth)]),
            SpatioTemporalAttention(dim = dim_out, **attn_kwargs) if self_attend else None,
            Downsample(dim_out, downsample_time = compress_time)
        ]))

        # decoder stage: entry block takes the upsampled features concatenated with a skip
        # (hence dim_out * 2); the first extra block takes dim_out more input channels,
        # presumably for a second skip concatenation — TODO confirm against forward()
        self.ups.append(mlist([
            ResnetBlock(dim_out * 2, dim_in, timestep_cond_dim = timestep_cond_dim),
            mlist([ResnetBlock(dim_in + (dim_out if ind == 0 else 0), dim_in) for ind in range(resnet_block_depth)]),
            SpatioTemporalAttention(dim = dim_in, **attn_kwargs) if self_attend else None,
            Upsample(dim_out, upsample_time = compress_time)
        ]))

    # multiplier applied to skip connections before concatenation in forward();
    # the original note says papers report faster convergence with it
    self.skip_scale = 2 ** -0.5

    self.conv_in = PseudoConv3d(dim = channels, dim_out = dim, kernel_size = 7, temporal_kernel_size = 3)
    self.conv_out = PseudoConv3d(dim = dim, dim_out = channels, kernel_size = 3, temporal_kernel_size = 3)
def forward(
    self,
    x,
    timestep = None,
    enable_time = True
):
    """Run the space-time U-Net over an image or video batch.

    Args:
        x: input tensor. A 5-dim tensor is treated as video with frames on
            dim 2; anything else is treated as images. Height and width (the
            last two dims) must be multiples of ``self.image_size_multiple``.
        timestep: diffusion timestep; must be given exactly when the network
            was built with timestep conditioning (``self.to_timestep_cond``).
        enable_time: when False, video input is processed without the
            temporal parts of the sub-modules.

    Returns:
        The output of ``self.conv_out``.
    """
    # timestep and timestep-conditioning must be both present or both absent
    assert not (exists(self.to_timestep_cond) ^ exists(timestep))

    if enable_time and x.ndim == 5:
        frames = x.shape[2]
        assert divisible_by(frames, self.frame_multiple), f'number of frames on the video ({frames}) must be divisible by the frame multiple ({self.frame_multiple})'

    h, w = x.shape[-2:]
    assert divisible_by(h, self.image_size_multiple) and divisible_by(w, self.image_size_multiple), f'height and width of the image or video must be a multiple of {self.image_size_multiple}'

    # flatten the timestep tensor and embed it for the resnet blocks
    cond = None
    if exists(timestep):
        cond = self.to_timestep_cond(rearrange(timestep, '... -> (...)'))

    x = self.conv_in(x, enable_time = enable_time)

    # two skip tensors are stored per down stage and consumed in reverse on the way up
    skips = []

    for init_block, blocks, maybe_attention, downsample in self.downs:
        x = init_block(x, cond, enable_time = enable_time)
        skips.append(x.clone())

        for block in blocks:
            x = block(x, enable_time = enable_time)

        if exists(maybe_attention):
            x = maybe_attention(x, enable_time = enable_time)

        skips.append(x.clone())
        x = downsample(x, enable_time = enable_time)

    x = self.mid_block1(x, cond, enable_time = enable_time)
    x = self.mid_attn(x, enable_time = enable_time)
    x = self.mid_block2(x, cond, enable_time = enable_time)

    def merge_skip(feats):
        # scaled skip goes first on the channel dim, matching the up-block input widths
        return torch.cat((skips.pop() * self.skip_scale, feats), dim = 1)

    for init_block, blocks, maybe_attention, upsample in reversed(self.ups):
        x = upsample(x, enable_time = enable_time)
        x = merge_skip(x)
        x = init_block(x, cond, enable_time = enable_time)
        x = merge_skip(x)

        for block in blocks:
            x = block(x, enable_time = enable_time)

        if exists(maybe_attention):
            x = maybe_attention(x, enable_time = enable_time)

    return self.conv_out(x, enable_time = enable_time)
.\lucidrains\make-a-video-pytorch\make_a_video_pytorch\__init__.py
# 从 make_a_video_pytorch.make_a_video 模块中导入 PseudoConv3d, SpatioTemporalAttention 类
from make_a_video_pytorch.make_a_video import PseudoConv3d, SpatioTemporalAttention
# 从 make_a_video_pytorch.make_a_video 模块中导入 ResnetBlock, Downsample, Upsample 类
from make_a_video_pytorch.make_a_video import ResnetBlock, Downsample, Upsample
# 从 make_a_video_pytorch.make_a_video 模块中导入 SpaceTimeUnet 类
from make_a_video_pytorch.make_a_video import SpaceTimeUnet

Make-A-Video - Pytorch (wip)
Implementation of Make-A-Video, new SOTA text to video generator from Meta AI, in Pytorch. They combine pseudo-3d convolutions (axial convolutions) and temporal attention and show much better temporal fusion.
Pseudo-3D convolutions aren't a new concept. They have been explored before in other contexts, e.g. for protein contact prediction as "dimensional hybrid residual networks".
The gist of the paper comes down to this: take a SOTA text-to-image model (here they use DALL-E2, but the same learning points would easily apply to Imagen), make a few minor modifications for attention across time along with other ways to cut the compute cost, do frame interpolation correctly, and get a great video model out.
Appreciation
-
Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research
-
Jonathan Ho for bringing about a revolution in generative artificial intelligence through his seminal paper
-
Alex for einops, an abstraction that is simply genius. No other word for it.
Install
$ pip install make-a-video-pytorch
Usage
Passing in video features
import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention
conv = PseudoConv3d(
dim = 256,
kernel_size = 3
)
attn = SpatioTemporalAttention(
dim = 256,
dim_head = 64,
heads = 8
)
video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)
conv_out = conv(video) # (1, 256, 8, 16, 16)
attn_out = attn(video) # (1, 256, 8, 16, 16)
Passing in images (if one were to pretrain on images first), both temporal convolution and attention will be automatically skipped. In other words, you can use this straightforwardly in your 2d Unet and then port it over to a 3d Unet once that phase of the training is done. The temporal modules are initialized to output identity as the paper had done.
import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention
conv = PseudoConv3d(
dim = 256,
kernel_size = 3
)
attn = SpatioTemporalAttention(
dim = 256,
dim_head = 64,
heads = 8
)
images = torch.randn(1, 256, 16, 16) # (batch, features, height, width)
conv_out = conv(images) # (1, 256, 16, 16)
attn_out = attn(images) # (1, 256, 16, 16)
You can also control the two modules so that, when fed video features (frames × height × width), they only operate spatially and ignore the time dimension
import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention
conv = PseudoConv3d(
dim = 256,
kernel_size = 3
)
attn = SpatioTemporalAttention(
dim = 256,
dim_head = 64,
heads = 8
)
video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)
# below it will not train across time
conv_out = conv(video, enable_time = False) # (1, 256, 8, 16, 16)
attn_out = attn(video, enable_time = False) # (1, 256, 8, 16, 16)
A full SpaceTimeUnet that is agnostic to image or video training; even when video is passed in, time can be ignored
import torch
from make_a_video_pytorch import SpaceTimeUnet
unet = SpaceTimeUnet(
dim = 64,
channels = 3,
dim_mult = (1, 2, 4, 8),
resnet_block_depths = (1, 1, 1, 2),
temporal_compression = (False, False, False, True),
self_attns = (False, False, False, True),
condition_on_timestep = False,
attn_pos_bias = False,
flash_attn = True
).cuda()
# train on images
images = torch.randn(1, 3, 128, 128).cuda()
images_out = unet(images)
assert images.shape == images_out.shape
# then train on videos
video = torch.randn(1, 3, 16, 128, 128).cuda()
video_out = unet(video)
assert video_out.shape == video.shape
# or even treat your videos as images
video_as_images_out = unet(video, enable_time = False)
Todo
Citations
@misc{Singer2022,
author = {Uriel Singer},
url = {https://makeavideo.studio/Make-A-Video.pdf}
}
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@article{Dong2021AttentionIN,
title = {Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth},
author = {Yihe Dong and Jean-Baptiste Cordonnier and Andreas Loukas},
journal = {ArXiv},
year = {2021},
volume = {abs/2103.03404}
}
@article{Zhang2021TokenST,
title = {Token Shift Transformer for Video Classification},
author = {Hao Zhang and Y. Hao and Chong-Wah Ngo},
journal = {Proceedings of the 29th ACM International Conference on Multimedia},
year = {2021}
}
@inproceedings{shleifer2022normformer,
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Sam Shleifer and Myle Ott},
booktitle = {Submitted to The Tenth International Conference on Learning Representations },
year = {2022},
url = {https://openreview.net/forum?id=GMYWzWztDx5},
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
.\lucidrains\make-a-video-pytorch\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# 设置包的信息
# Package metadata for make-a-video-pytorch; list-valued fields are named up front.
_KEYWORDS = [
  'artificial intelligence',
  'deep learning',
  'attention mechanism',
  'text-to-video',
  'axial convolutions'
]

_INSTALL_REQUIRES = [
  'classifier-free-guidance-pytorch',
  'einops>=0.6',
  'torch>=1.6',
]

_CLASSIFIERS = [
  'Development Status :: 4 - Beta',
  'Intended Audience :: Developers',
  'Topic :: Scientific/Engineering :: Artificial Intelligence',
  'License :: OSI Approved :: MIT License',
  'Programming Language :: Python :: 3.6',
]

setup(
  name = 'make-a-video-pytorch',
  packages = find_packages(exclude = []),   # every package found under the project root
  version = '0.3.1',
  license = 'MIT',
  description = 'Make-A-Video - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/make-a-video-pytorch',
  keywords = _KEYWORDS,
  install_requires = _INSTALL_REQUIRES,
  classifiers = _CLASSIFIERS,
)
.\lucidrains\MaMMUT-pytorch\mammut_pytorch\mammut_pytorch.py
# 导入 torch 库
import torch
# 从 torch 库中导入 einsum, nn 模块
from torch import einsum, nn
# 从 torch 库中导入 F 模块
import torch.nn.functional as F
# 从 torch 库中导入 distributed 模块
import torch.distributed as dist
# 从 torch 库中导入 Function 模块
from torch.autograd import Function
# 从 einops 库中导入 rearrange, repeat 函数
from einops import rearrange, repeat
# 辅助函数
# 判断变量是否存在
def exists(val):
    """Report whether *val* holds a value, i.e. is not None."""
    if val is None:
        return False
    return True
# 如果变量存在则返回该变量,否则返回默认值
def default(val, d):
    """Return *val* when it exists (is not None), otherwise the fallback *d*."""
    if exists(val):
        return val
    return d
# 判断一个数是否可以被另一个数整除
def divisible_by(numer, denom):
    """Return True when *numer* is an exact multiple of *denom*."""
    remainder = numer % denom
    return remainder == 0
# 分布式
# 在指定维度上对张量进行填充,使其达到指定长度
def pad_dim_to(t, length, dim = 0):
    """Zero-pad tensor *t* at the end of dimension *dim* so that it has size *length*.

    F.pad consumes (left, right) pairs starting from the LAST dimension, so
    each dimension after *dim* gets a (0, 0) pair before the real pair.
    """
    pad_length = length - t.shape[dim]
    trailing_dims = t.ndim - dim - 1 if dim >= 0 else -dim - 1
    padding = (0, 0) * trailing_dims + (0, pad_length)
    return F.pad(t, padding)
# 对所有进程中的张量进行收集
def all_gather_variable_batch(t):
    """All-gather tensors whose dim-0 (batch) size may differ between ranks.

    Every rank's tensor is zero-padded along dim 0 to the largest batch size
    in the group, all-gathered, and the padding rows are removed afterwards.
    Requires an initialized torch.distributed process group.

    Returns:
        (gathered_tensor, sizes): the rows of all ranks concatenated in rank
        order, and a Python list of each rank's original batch size.
    """
    device, world_size = t.device, dist.get_world_size()

    # exchange batch sizes so each rank knows how much padding to strip later
    size = torch.tensor(t.shape[0], device = device, dtype = torch.long)
    sizes = [torch.empty_like(size) for _ in range(world_size)]
    dist.all_gather(sizes, size)
    sizes = torch.stack(sizes)

    # all_gather needs equal shapes, so pad everyone to the largest batch
    max_size = sizes.amax().item()
    padded_t = pad_dim_to(t, max_size, dim = 0)

    gathered_tensors = [torch.empty_like(padded_t) for _ in range(world_size)]
    dist.all_gather(gathered_tensors, padded_t)
    gathered_tensor = torch.cat(gathered_tensors)

    # row j of rank i holds real data iff j < sizes[i]
    seq = torch.arange(max_size, device = device)
    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    gathered_tensor = gathered_tensor[mask]

    return gathered_tensor, sizes.tolist()
# 自定义的 AllGather 函数
class AllGather(Function):
    """Autograd wrapper around all_gather_variable_batch.

    forward: concatenates every rank's rows (batch sizes may differ).
    backward: splits the incoming gradient by the recorded per-rank batch
    sizes and returns the piece belonging to the current rank.
    """

    @staticmethod
    def forward(ctx, x):
        assert dist.is_initialized() and dist.get_world_size() > 1
        gathered, ctx.batch_sizes = all_gather_variable_batch(x)
        return gathered

    @staticmethod
    def backward(ctx, grads):
        own_rank = dist.get_rank()
        per_rank_grads = grads.split(ctx.batch_sizes, dim = 0)
        return per_rank_grads[own_rank]

# functional entry point for AllGather
all_gather = AllGather.apply
# 归一化
# 使用不带偏置的 layernorm,这是 PyTorch 不提供的功能
# Layernorm 类
class LayerNorm(nn.Module):
    """Layer norm over the last dim with a learnable gain and a constant zero bias.

    The bias is registered as a buffer rather than a Parameter, so the
    optimizer never updates it.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, weight = self.gamma, bias = self.beta)
# 残差连接
# Residual 类
class Residual(nn.Module):
    """Wrap *fn* so that calling the module returns fn(x, ...) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x
# 转换为潜变量
# EmbedToLatents 类
class EmbedToLatents(nn.Module):
    """Bias-free linear projection to *dim_latents*, L2-normalized along the last dim."""

    def __init__(self, dim, dim_latents):
        super().__init__()
        self.to_latents = nn.Linear(dim, dim_latents, bias = False)

    def forward(self, x):
        return F.normalize(self.to_latents(x), dim = -1)
# 旋转位置嵌入
# https://arxiv.org/abs/2104.09864
# RotaryEmbedding 类
class RotaryEmbedding(nn.Module):
    """Rotary position angle table (https://arxiv.org/abs/2104.09864).

    forward(n) returns position * inv_freq for positions 0..n-1, with the
    frequency axis duplicated (concatenated with itself) along the last dim,
    i.e. shape (n, dim) for even dim.
    """

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, max_seq_len, *, device):
        positions = torch.arange(max_seq_len, device = device, dtype = self.inv_freq.dtype)
        # outer product: one angle per (position, frequency)
        angles = positions[:, None] * self.inv_freq[None, :]
        return torch.cat((angles, angles), dim = -1)
# 将张量旋转一半
def rotate_half(x):
    """Split the last dim into two halves (x1, x2) and return cat(-x2, x1)."""
    pair = rearrange(x, "... (j d) -> ... j d", j = 2)
    lo, hi = pair.unbind(dim = -2)
    return torch.cat((hi.neg(), lo), dim = -1)
# 应用旋转位置嵌入
def apply_rotary_pos_emb(pos, t):
    """Rotate *t* by the angles in *pos*: t * cos(pos) + rotate_half(t) * sin(pos)."""
    cos, sin = pos.cos(), pos.sin()
    return t * cos + rotate_half(t) * sin
# 经典的 Noam Shazeer 论文,这里使用 SwiGLU 代替更流行的 GEGLU 用于门控前馈
# https://arxiv.org/abs/2002.05202
# SwiGLU 类
class SwiGLU(nn.Module):
    """Gated activation: split the last dim into (x, gate) and return silu(gate) * x.

    See https://arxiv.org/abs/2002.05202.
    """

    def forward(self, x):
        value, gate = x.chunk(2, dim = -1)
        return value * F.silu(gate)
# 并行注意力和前馈,带有残差连接
# 定义一个并行Transformer块的类
class ParallelTransformerBlock(nn.Module):
    """Pre-norm block computing causal attention and a SwiGLU feedforward in
    parallel from a single fused projection, returning their sum.

    Queries have `heads` heads; keys and values are one shared head of size
    `dim_head` (multi-query attention). Rotary embeddings are applied to q and k.
    The block adds no residual connection itself.
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = LayerNorm(dim)
        self.rotary_emb = RotaryEmbedding(dim_head)

        # widths of q, k, v and the feedforward input (doubled for the SwiGLU gate)
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, ff_inner_dim * 2)
        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)

        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # lazily built and reused while large enough: causal mask and rotary table
        self.mask = None
        self.pos_emb = None

    def get_mask(self, n, device):
        """Return an (n, n) boolean mask that is True strictly above the diagonal."""
        cached = self.mask
        if cached is not None and cached.shape[-1] >= n:
            return cached[:n, :n].to(device)

        self.mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        return self.mask

    def get_rotary_embedding(self, n, device):
        """Return rotary angles for the first n positions, reusing the cache when possible."""
        cached = self.pos_emb
        if cached is not None and cached.shape[-2] >= n:
            return cached[:n].to(device)

        self.pos_emb = self.rotary_emb(n, device=device)
        return self.pos_emb

    def forward(self, x, attn_mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension

        attn_mask: optional boolean (b, i, j) mask; False entries are blocked
        in addition to the causal mask.
        """
        n, device, h = x.shape[1], x.device, self.heads

        x = self.norm(x)

        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        positions = self.get_rotary_embedding(n, device)
        q = apply_rotary_pos_emb(positions, q)
        k = apply_rotary_pos_emb(positions, k)

        q = q * self.scale

        # k has no head axis: every query head attends over the same keys
        sim = einsum("b h i d, b j d -> b h i j", q, k)

        mask_value = -torch.finfo(sim.dtype).max
        sim = sim.masked_fill(self.get_mask(n, device), mask_value)

        if exists(attn_mask):
            sim = sim.masked_fill(~rearrange(attn_mask, 'b i j -> b 1 i j'), mask_value)

        attn = sim.softmax(dim=-1)
        out = einsum("b h i j, b j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")

        return self.attn_out(out) + self.ff_out(ff)
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
class CrossAttention(nn.Module):
    """Multi-query cross attention: `heads` query heads over one shared
    key/value head computed from `context`.

    With parallel_ff=True a SwiGLU feedforward of the normed input is added
    to the attention output. The module adds no residual itself.
    """

    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False
    ):
        super().__init__()
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)  # single shared k/v head
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        ff = None
        if parallel_ff:
            ff_inner_dim = ff_mult * dim
            ff = nn.Sequential(
                nn.Linear(dim, ff_inner_dim * 2, bias=False),
                SwiGLU(),
                nn.Linear(ff_inner_dim, dim, bias=False)
            )
        self.ff = ff

    def forward(self, x, context):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        x = self.norm(x)
        context = self.context_norm(context)

        q = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h = self.heads) * self.scale
        k, v = self.to_kv(context).chunk(2, dim=-1)

        attn = einsum('b h i d, b j d -> b h i j', q, k).softmax(dim=-1)
        out = einsum('b h i j, b j d -> b h i d', attn, v)
        out = self.to_out(rearrange(out, 'b h n d -> b n (h d)'))

        if exists(self.ff):
            out = out + self.ff(x)

        return out
# 定义一个名为 MaMMUT 的类,继承自 nn.Module 类
class MaMMUT(nn.Module):
    """Single text decoder trained with a caption loss and an image-text contrastive loss.

    The text goes through the decoder layers twice: once without cross
    attention (embed_text, yielding a CLS embedding for the contrastive loss)
    and once with cross attention to pooled image tokens (yielding logits).
    """

    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        cross_attend_every=1,
        cross_attend_layers=None,
        dim_latents=None,
        image_dim=None,
        num_img_queries=256,
        dim_head=64,
        heads=8,
        ff_mult=4,
        img_encoder=None,
        caption_loss_weight=1.,
        contrastive_loss_weight=1.,
        pad_id=0
    ):
        super().__init__()
        self.dim = dim
        self.pad_id = pad_id
        self.caption_loss_weight = caption_loss_weight
        self.contrastive_loss_weight = contrastive_loss_weight

        # token embeddings plus a learned CLS token appended to every text sequence
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.text_cls_token = nn.Parameter(torch.randn(dim))

        # optional image encoder; precomputed image tokens may be passed instead
        self.img_encoder = img_encoder

        # attention pooling: num_img_queries queries for the multimodal layers,
        # plus 1 extra query whose output is the image CLS embedding for contrastive learning
        self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, dim))
        self.img_attn_pool = CrossAttention(dim=dim, context_dim=image_dim, dim_head=dim_head, heads=heads, norm_context=True)

        self.img_attn_pool_norm = LayerNorm(dim)
        self.text_cls_norm = LayerNorm(dim)

        # projections into the shared, L2-normalized latent space
        dim_latents = default(dim_latents, dim)
        self.img_to_latents = EmbedToLatents(dim, dim_latents)
        self.text_to_latents = EmbedToLatents(dim, dim_latents)

        # learnable contrastive temperature, applied as exp(temperature)
        self.temperature = nn.Parameter(torch.Tensor([1.]))

        # decoder layers: cross attention on every `cross_attend_every`-th layer
        # (1-indexed), or exactly on `cross_attend_layers` when that tuple is given
        self.layers = nn.ModuleList([])
        for ind in range(depth):
            layer = ind + 1
            has_cross_attn = divisible_by(layer, cross_attend_every)
            if exists(cross_attend_layers):
                assert isinstance(cross_attend_layers, tuple)
                has_cross_attn = layer in cross_attend_layers

            self.layers.append(nn.ModuleList([
                Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
                Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)) if has_cross_attn else None
            ]))

        self.to_logits = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_tokens, bias=False)
        )

        # output projection shares its weight with the token embedding
        self.to_logits[-1].weight = self.token_emb.weight
        nn.init.normal_(self.token_emb.weight, std=0.02)

        # decided once at construction: gather latents across ranks only if a
        # process group with more than one rank is already initialized
        self.is_data_parallel = dist.is_initialized() and dist.get_world_size() > 1

    def embed_text(self, text):
        """Run token ids through the decoder layers without cross attention.

        A CLS token is appended to every sequence. Its attention row blocks
        positions where text == pad_id.

        Returns:
            (text_embeds, text_tokens): the normed CLS output, and the
            per-token outputs with the CLS slot removed.
        """
        batch, seq = text.shape[0], text.shape[1]

        text_tokens = self.token_emb(text)

        text_cls_tokens = repeat(self.text_cls_token, 'd -> b 1 d', b=batch)
        text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2)

        # (b, seq + 1, seq + 1) mask: all True except the last (CLS) row,
        # which is False at padding positions
        cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

        for attn_ff, _ in self.layers:
            text_tokens = attn_ff(text_tokens, attn_mask=attn_mask)

        text_tokens, text_cls_tokens = text_tokens[:, :-1], text_tokens[:, -1]
        text_embeds = self.text_cls_norm(text_cls_tokens)
        return text_embeds, text_tokens

    def embed_image(self, images=None, image_tokens=None):
        """Pool image tokens with the learned queries.

        Pass either raw `images` (encoded with self.img_encoder) or
        precomputed `image_tokens`, but not both.

        Returns:
            (image_embeds, image_tokens): the first pooled query (image CLS
            embedding) and the remaining pooled queries.
        """
        assert not (exists(images) and exists(image_tokens))

        if exists(images):
            assert exists(self.img_encoder), 'img_encoder must be passed in for automatic image encoding'
            image_tokens = self.img_encoder(images)

        assert exists(image_tokens), 'either images or image_tokens must be passed in'

        img_queries = repeat(self.img_queries, 'n d -> b n d', b=image_tokens.shape[0])
        img_queries = self.img_attn_pool(img_queries, image_tokens)
        img_queries = self.img_attn_pool_norm(img_queries)

        return img_queries[:, 0], img_queries[:, 1:]

    def forward(
        self,
        text,
        text_mask = None,
        images=None,
        image_tokens=None,
        labels=None,
        return_loss=False,
        return_embeddings=False
    ):
        """Compute caption logits, CLIP-style embeddings, or the combined loss.

        Args:
            text: token ids. When return_loss is set without labels, the
                sequence is shifted internally to form inputs and labels.
            text_mask: accepted for API compatibility; not used here.
            images / image_tokens: see embed_image.
            return_embeddings: return (text_embeds, image_embeds) early.
            return_loss: return weighted caption loss + contrastive loss.

        Returns:
            logits, (text_embeds, image_embeds), or a scalar loss.
        """
        device = text.device

        if return_loss and not exists(labels):
            text, labels = text[:, :-1], text[:, 1:]

        text_embeds, _ = self.embed_text(text)
        image_embeds, image_tokens = self.embed_image(images=images, image_tokens=image_tokens)

        if return_embeddings:
            return text_embeds, image_embeds

        # second pass through the layers, now with cross attention to the image
        text_tokens = self.token_emb(text)
        for attn_ff, cross_attn in self.layers:
            text_tokens = attn_ff(text_tokens)
            if exists(cross_attn):
                text_tokens = cross_attn(text_tokens, image_tokens)

        logits = self.to_logits(text_tokens)

        if not return_loss:
            return logits

        ce = F.cross_entropy

        # caption (autoregressive cross entropy) loss
        logits = rearrange(logits, 'b n c -> b c n')
        caption_loss = ce(logits, labels, ignore_index=self.pad_id)
        caption_loss = caption_loss * self.caption_loss_weight

        text_latents = self.text_to_latents(text_embeds)
        image_latents = self.img_to_latents(image_embeds)

        # under data parallelism, contrast against the latents of every rank
        if self.is_data_parallel:
            latents = torch.stack((text_latents, image_latents), dim = 1)
            latents = all_gather(latents)
            text_latents, image_latents = latents.unbind(dim = 1)

        sim = einsum('i d, j d -> i j', text_latents, image_latents)
        sim = sim * self.temperature.exp()

        # matching pairs sit on the diagonal. Size the labels from sim: after
        # all_gather it spans the global batch, not this rank's local batch
        contrastive_labels = torch.arange(sim.shape[0], device=device)
        contrastive_loss = (ce(sim, contrastive_labels) + ce(sim.t(), contrastive_labels)) * 0.5
        contrastive_loss = contrastive_loss * self.contrastive_loss_weight

        return caption_loss + contrastive_loss
.\lucidrains\MaMMUT-pytorch\mammut_pytorch\__init__.py
# 从 mammut_pytorch 包中导入 MaMMUT 类
from mammut_pytorch.mammut_pytorch import MaMMUT

MaMMUT - Pytorch
Implementation of MaMMUT, a simple vision-encoder text-decoder architecture for multimodal tasks from Google, in Pytorch. Blog post
This work is basically just a simplified CoCa. I copied the code from this repository and made the change described in the paper, which is simply to do two passes through the text decoder: one with cross attention for the generative loss, and the other without cross attention for the contrastive loss.
This is also a good time to plug an open sourced version of CoCa from the folks at OpenCLIP!
Appreciation
- Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research
Install
$ pip install mammut-pytorch
Usage
First install vit-pytorch for the image encoder, which needs to be pretrained (quote the requirement so the shell does not treat `>` as a redirect)
$ pip install "vit-pytorch>=0.40.2"
Then
import torch
# import vision transformer
from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
from vit_pytorch.extractor import Extractor
vit = SimpleViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
patch_dropout = 0.5 # https://arxiv.org/abs/2212.00794
)
vit = Extractor(vit, return_embeddings_only = True, detach = False)
# extractor will enable it so the vision transformer returns its embeddings
# import MaMMUT and instantiate it
from mammut_pytorch.mammut_pytorch import MaMMUT
mammut = MaMMUT(
dim = 512, # model dimension
img_encoder = vit, # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
image_dim = 1024, # image embedding dimension, if not the same as model dimensions
num_tokens = 20000, # number of text tokens
depth = 6, # depth of the transformer
dim_head = 64, # dimension per attention head
heads = 8, # number of attention heads
caption_loss_weight = 1., # weight on the autoregressive caption loss
contrastive_loss_weight = 1., # weight on the contrastive loss between image and text CLS embeddings
).cuda()
# mock text and images
text = torch.randint(0, 20000, (4, 512)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# train by giving MaMMUT your text and images with `return_loss = True`
loss = mammut(
text = text,
images = images,
return_loss = True # set this to True to get the full caption + contrastive loss
)
loss.backward()
# do the above for as much text and images...
# then you can get the caption logits as so
logits = mammut(
text = text,
images = images
) # (4, 512, 20000)
# and the CLIP-like text and image embeddings as
text_embeds, image_embeds = mammut(
text = text,
images = images,
return_embeddings = True
) # (4, 512), (4, 512)
One of the main findings of the paper is that different tasks perform differently depending on the amount of cross attention. This repository will give you full control over how much cross attention you want to place in the network.
mammut = MaMMUT(
dim = 512,
img_encoder = vit,
image_dim = 1024,
num_tokens = 20000,
depth = 6,
cross_attend_every = 2, # say you want to cross attend only every 2 layers
dim_head = 64,
heads = 8,
caption_loss_weight = 1.,
contrastive_loss_weight = 1.
).cuda()
# or you can finely specify which layers to do cross attention
mammut = MaMMUT(
dim = 512,
img_encoder = vit,
image_dim = 1024,
num_tokens = 20000,
depth = 6,
cross_attend_layers = (4, 5, 6), # only last three layers have cross attention
dim_head = 64,
heads = 8,
caption_loss_weight = 1.,
contrastive_loss_weight = 1.
).cuda()
Todo
Citations
@article{Kuo2023MaMMUTAS,
title = {MaMMUT: A Simple Architecture for Joint Learning for MultiModal Tasks},
author = {Weicheng Kuo and A. J. Piergiovanni and Dahun Kim and Xiyang Luo and Benjamin Caine and W. Li and Abhijit S. Ogale and Luowei Zhou and Andrew M. Dai and Zhifeng Chen and Claire Cui and Anelia Angelova},
journal = {ArXiv},
year = {2023},
volume = {abs/2303.16839}
}
@inproceedings{Chowdhery2022PaLMSL,
title = {PaLM: Scaling Language Modeling with Pathways},
author = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
year = {2022}
}
.\lucidrains\MaMMUT-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
# Package metadata for MaMMUT-pytorch; list-valued fields are named up front.
_KEYWORDS = [
  'artificial intelligence',
  'deep learning',
  'multimodal',
  'attention mechanism',
  'contrastive learning'
]

_INSTALL_REQUIRES = [
  'einops>=0.6.1',
  'torch>=1.6',
]

_CLASSIFIERS = [
  'Development Status :: 4 - Beta',
  'Intended Audience :: Developers',
  'Topic :: Scientific/Engineering :: Artificial Intelligence',
  'License :: OSI Approved :: MIT License',
  'Programming Language :: Python :: 3.6',
]

setup(
  name = 'MaMMUT-pytorch',
  packages = find_packages(exclude = []),   # every package found under the project root
  version = '0.0.7',
  license = 'MIT',
  description = 'MaMMUT - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/MaMMUT-pytorch',
  keywords = _KEYWORDS,
  install_requires = _INSTALL_REQUIRES,
  classifiers = _CLASSIFIERS,
)
.\lucidrains\marge-pytorch\marge_pytorch\autoregressive_wrapper.py
from functools import partial
import torch
import random
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
# 定义一个函数,返回 value 或者 default 中的值
def default(value, default):
    """Return *value* unless it is None, in which case return *default*."""
    if value is None:
        return default
    return value
# 对输入张量取对数,加上一个很小的值 eps 防止出现取对数时的错误
def log(t, eps=1e-9):
    """Natural log of *t* shifted by *eps*, so an input of 0 gives a finite result."""
    shifted = t + eps
    return shifted.log()
# 根据 top-p 策略过滤 logits
def top_p(logits, thres = 0.9):
    """Nucleus filtering for 2-D logits (rows are independent distributions).

    Logits are sorted in descending order. Tokens whose cumulative
    probability exceeds 1 - thres are set to -inf, except that the token
    crossing the boundary (and always the top token) is kept. Remaining
    logits are returned in their original positions.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

    remove = cum_probs > 1.0 - thres
    # shift right by one so the boundary-crossing token survives
    remove[:, 1:] = remove[:, :-1].clone()
    remove[:, 0] = 0

    sorted_logits[remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)
# 根据 top-k 策略过滤 logits
def top_k(logits, thres = 0.9):
    """Keep the largest (1 - thres) fraction of logits along the last dim; set the rest to -inf.

    At least one logit is always kept. With a small vocabulary or thres close
    to 1, int((1 - thres) * vocab) can round down to 0 (e.g. (1 - 0.9) * 10
    evaluates to 0.999...), which would otherwise mask every token.
    """
    k = max(int((1 - thres) * logits.shape[-1]), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    # scatter along the same (last) dim that topk selected from
    probs.scatter_(-1, ind, val)
    return probs
# 自回归包装器类
class AutoregressiveWrapper(nn.Module):
    """Wrap a sequence model for next-token training and sampling.

    Expects `net` to expose `max_seq_len` and to return a tuple/sequence
    whose first element is logits indexed as [batch, position, vocab].
    """

    def __init__(self, net, ignore_index = None, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = default(ignore_index, pad_value)
        self.net = net
        self.max_seq_len = net.max_seq_len

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        """Sample up to `seq_len` tokens after `start_tokens`.

        Args:
            start_tokens: (batch, t) or (t,) token ids; a 1-D prompt yields a 1-D result.
            src_mask (kwarg, optional): boolean mask aligned with start_tokens,
                defaulting to all True; generated tokens get True entries.
            filter_logits_fn / filter_thres: filter applied to the last-step logits.
            eos_token: stop early once every row samples this token.

        Returns:
            Only the newly generated tokens (the prompt is stripped). The
            net's training mode is restored afterwards, even on error.
        """
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        prompt_len = start_tokens.shape[1]

        self.net.eval()
        try:
            out = start_tokens
            input_mask = kwargs.pop('src_mask', None)

            if input_mask is None:
                input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

            for _ in range(seq_len):
                # condition on at most the last max_seq_len tokens and their mask
                x = out[:, -self.max_seq_len:]
                input_mask = input_mask[:, -self.max_seq_len:]

                logits, *_ = self.net(x, src_mask=input_mask, **kwargs)
                logits = logits[:, -1, :]

                filtered_logits = filter_logits_fn(logits, thres = filter_thres)
                # Gumbel-max sampling from the temperature-scaled filtered logits
                gumbel_noise = -log(-log(torch.zeros_like(filtered_logits).uniform_(0, 1)))
                sample = ((filtered_logits / temperature) + gumbel_noise).argmax(dim=-1)

                out = torch.cat((out, sample[:, None]), dim=-1)
                # the new token is appended on the right, so its mask entry goes on
                # the right too; left-padding would shift False entries off their tokens
                input_mask = F.pad(input_mask, (0, 1), value=True)

                if eos_token is not None and (sample == eos_token).all():
                    break
        finally:
            self.net.train(was_training)

        out = out[:, prompt_len:]

        if num_dims == 1:
            out = out.squeeze(0)

        return out

    def forward(self, x, *args, **kwargs):
        """Teacher-forced next-token loss.

        x[:, :-1] is fed to the net and x[:, 1:] is the target. An optional
        `input_mask` kwarg must match x's first two dims; it is trimmed to
        the input length before being forwarded.

        Returns:
            (loss, *rest), where rest are any extra outputs of the net.
        """
        m = kwargs.pop('input_mask', None)
        xi, xo = x[:, :-1], x[:, 1:]

        if m is not None:
            assert m.shape == x.shape[0:2], 'input mask must be the same shape as the input of the auto-regressive wrapper to automatically handle'
            kwargs.update(input_mask = m[:, :-1])

        out, *rest = self.net(xi, *args, **kwargs)
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return (loss, *rest)
.\lucidrains\marge-pytorch\marge_pytorch\marge_pytorch.py
# 导入必要的库
import faiss
import math
import numpy as np
from tqdm import tqdm
from einops import rearrange, repeat
from functools import partial
from contextlib import contextmanager
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, einsum
import torch.nn.functional as F
from marge_pytorch.autoregressive_wrapper import AutoregressiveWrapper
# helper functions

def identity(x, *args, **kwargs):
    """Return `x` unchanged; any extra positional or keyword arguments are ignored."""
    result = x
    return result
def exists(x):
    """Return True when `x` is not None (falsy values such as 0 or '' still count)."""
    is_missing = x is None
    return not is_missing
def default(x, d):
    """Return `x` unless it is None, in which case return the fallback `d`."""
    if x is None:
        return d
    return x
def chunk(chunk_size, l):
    """Yield consecutive `slice` objects of at most `chunk_size` items covering range(l)."""
    for start in range(0, l, chunk_size):
        end = start + chunk_size
        yield slice(start, min(end, l))
def max_neg_value(tensor):
    """Return the most negative finite value representable in `tensor`'s floating dtype."""
    dtype_info = torch.finfo(tensor.dtype)
    largest = dtype_info.max
    return -largest
# Context manager that opens a numpy memory map for the duration of a `with` block.
@contextmanager
def memmap(*args, **kwargs):
    """Yield `np.memmap(*args, **kwargs)` and release it when the block exits.

    Cleanup runs whether the block exits normally or by exception. Maps opened
    in a writable mode ('r+' or 'w+') are flushed so pending writes reach the
    file, then this function's reference to the map is dropped.
    """
    pointer = np.memmap(*args, **kwargs)
    try:
        yield pointer
    finally:
        # 'r' is read-only and 'c' is copy-on-write: neither has anything to write back
        if pointer.mode in ('r+', 'w+'):
            pointer.flush()
        del pointer
# Attention distillation loss: pushes the document-similarity distribution
# towards the (detached) average cross-attention paid to each evidence document.
def distill_attn_loss(evi_dots, doc_similarities, mask = None, eps = 1e-5):
    """Compute KL(softmax(mean attention per document) || softmax(doc_similarities)).

    evi_dots: (b, l, h, i, n, j) attention scores, where the axes are
        (batch, layers, heads, query positions, evidence documents, tokens per document).
    doc_similarities: (b, n) similarity logits, one per evidence document.
    mask: optional bool (b, n, j); False entries are excluded from the mean.
    eps: guards the masked mean against division by zero.

    Returns a scalar loss reduced with 'batchmean'. The caller's `evi_dots` is
    not modified, and no gradient flows into the attention-derived target.
    """
    # merge layers, heads and query positions: (b, l*h*i, n, j)
    evi_dots = evi_dots.flatten(1, 3)

    if mask is not None:
        mask = mask[:, None]  # (b, 1, n, j), broadcast over the merged axis
        # out-of-place so the caller's attention tensor is left untouched
        evi_dots = evi_dots.masked_fill(~mask, 0.)
        denom = mask.expand_as(evi_dots).sum(dim = (1, -1))
        evi_dots_mean = evi_dots.sum(dim = (1, -1)) / (denom + eps)
    else:
        evi_dots_mean = evi_dots.mean(dim = (1, -1))

    # the attention-derived distribution is a fixed target
    normed_evi_dots = evi_dots_mean.softmax(dim = -1).detach()
    # log_softmax is the numerically stable form of softmax().log() (no -inf underflow)
    log_doc_similarities = doc_similarities.log_softmax(dim = -1)
    loss = F.kl_div(log_doc_similarities, normed_evi_dots, reduction = 'batchmean')
    return loss
# helper classes

# Pre-normalization: LayerNorm the input, then run the wrapped module on it.
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """Return fn(LayerNorm(x), *args, **kwargs)."""
        return self.fn(self.norm(x), *args, **kwargs)
# GEGLU: split the last dimension in half and gate the first half with GELU(second half).
class GEGLU(nn.Module):
    def forward(self, x):
        halves = x.chunk(2, dim = -1)
        value, gate = halves[0], halves[1]
        return F.gelu(gate) * value
# Position-wise feed-forward network with a GEGLU activation.
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        # scale the multiplier by 2/3 (truncated to an int) to keep parameter
        # count / compute roughly constant relative to a non-GLU feed-forward
        hidden = dim * int(mult / 3 * 2)
        layers = [
            nn.Linear(dim, hidden * 2),  # doubled width: GEGLU splits it into value and gate
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
# Multi-head dot-product self-attention with an optional padding mask and
# optional causal masking.
class SelfAttention(nn.Module):
    def __init__(self, dim, heads = 8, causal = True, dropout = 0.):
        super().__init__()
        # NOTE(review): scales by the full model dim rather than the per-head dim
        self.scale = dim ** -0.5
        self.heads = heads
        self.causal = causal
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask = None):
        """x: (b, n, dim); mask: optional bool (b, n) where False marks padding."""
        _, seq_len, _ = x.shape
        projected = self.to_qkv(x)
        q, k, v = rearrange(projected, 'b n (qkv h d) -> qkv b h n d', h = self.heads, qkv = 3)

        scores = einsum('bhid,bhjd->bhij', q, k) * self.scale
        fill_value = max_neg_value(scores)

        if exists(mask):
            # a (query, key) pair is kept only when both positions are unmasked
            pair_mask = mask[:, None, :, None] * mask[:, None, None, :]
            scores.masked_fill_(~pair_mask, fill_value)
            del pair_mask

        if self.causal:
            # block attention to future positions (strict upper triangle)
            future = torch.ones(seq_len, seq_len, device=x.device).triu_(1).bool()
            scores.masked_fill_(future, fill_value)
            del future

        weights = self.dropout(scores.softmax(dim=-1))
        attended = einsum('bhij,bhjd->bhid', weights, v)
        merged = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(merged)
class CrossAttention(nn.Module):
    """Multi-head attention from decoder positions onto the encoded evidence
    documents. Each document's similarity score, scaled by a learnable `beta`,
    is added as a bias to the attention scores of all of that document's keys."""

    def __init__(self, dim, heads = 8, dropout = 0.):
        super().__init__()
        # NOTE(review): scales by the full model dim rather than the per-head dim
        self.scale = dim ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(dim, dim, bias = False)
        self.to_kv = nn.Linear(dim, dim * 2, bias = False)
        # learnable weight on the document-similarity bias
        self.beta = nn.Parameter(torch.tensor(1.), requires_grad=True)
        self.to_out = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context, doc_similarities, mask = None, context_mask = None):
        """
        x: (b, n, dim) queries.
        context: (b, m, j, dim) encodings of m evidence documents of j tokens each.
        doc_similarities: (b, m), one score per document.
        mask: optional bool (b, n); context_mask: optional bool (b, m, j).

        Returns (output (b, n, dim), attention scores taken before the
        similarity bias and masking were applied (b, heads, n, m * j)).
        """
        batch, n, _ = x.shape
        h = self.heads
        device = x.device

        q = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h = h)

        # flatten documents into one key sequence of length m * j
        doc_len = context.shape[2]
        context = rearrange(context, 'b m n d -> b (m n) d')
        if exists(context_mask):
            context_mask = rearrange(context_mask, 'b m n -> b (m n)')

        # broadcast each document's similarity to every one of its tokens: (b, 1, 1, m*j)
        bias = repeat(doc_similarities, 'b m -> b (m n)', n = doc_len)
        bias = bias[:, None, None, :] * self.beta

        k, v = rearrange(self.to_kv(context), 'b n (kv h d) -> kv b h n d', h = h, kv = 2)

        raw_scores = einsum('bhid,bhjd->bhij', q, k) * self.scale
        # out-of-place add, so the returned raw_scores are untouched by masking below
        scores = raw_scores + bias

        if exists(mask) or exists(context_mask):
            if not exists(mask):
                mask = torch.full((batch, n), True, dtype=torch.bool, device=device)
            if not exists(context_mask):
                context_mask = torch.full(context.shape[:2], True, dtype=torch.bool, device=device)
            pair_mask = mask[:, None, :, None] * context_mask[:, None, None, :]
            scores.masked_fill_(~pair_mask, max_neg_value(scores))
            del pair_mask

        weights = self.dropout(scores.softmax(dim=-1))
        attended = einsum('bhij,bhjd->bhid', weights, v)
        out = self.to_out(rearrange(attended, 'b h n d -> b n (h d)'))
        return out, raw_scores
class Encoder(nn.Module):
    """Bidirectional pre-norm transformer encoder with a learned CLS token.

    The first `retrieval_depth` layers form the retrieval encoder: the CLS
    output after them is the document embedding. The remaining layers further
    refine the token encodings.
    """

    def __init__(self, dim, depth, retrieval_depth = 4, heads = 8, ff_mult = 4, attn_dropout = 0., ff_dropout = 0.):
        super().__init__()
        assert depth > retrieval_depth, f'Depth must be greater than the depth set for the retrieval encoder ({retrieval_depth})'

        # one transformer layer: self-attention then feed-forward, each pre-normed;
        # `heads` and `ff_dropout` are forwarded so the constructor arguments take effect
        block = lambda: nn.ModuleList([
            PreNorm(dim, SelfAttention(dim, heads = heads, causal = False, dropout = attn_dropout)),
            PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
        ])

        self.cls = nn.Parameter(torch.zeros(1, dim), requires_grad=True)
        self.encoder_head = nn.ModuleList([block() for _ in range(retrieval_depth)])
        self.encoder_tail = nn.ModuleList([block() for _ in range(depth - retrieval_depth)])

    def forward(self, x, src_mask = None, return_embed_only = False):
        """x: (b, n, dim) embeddings; src_mask: optional bool (b, n).

        Returns (cls_embedding (b, dim), None) when `return_embed_only`, else
        (token encodings without the CLS position (b, n, dim), cls_embedding (b, dim)).
        """
        b, _, _ = x.shape

        # prepend the CLS token; its position is always visible
        cls_token = repeat(self.cls, 'n d -> b n d', b = b)
        x = torch.cat((cls_token, x), dim = 1)
        src_mask = F.pad(src_mask, (1, 0), value = True) if exists(src_mask) else None

        for attn, ff in self.encoder_head:
            x = attn(x, mask = src_mask) + x
            x = ff(x) + x

        cls_tokens = x[:, 0]

        if return_embed_only:
            return cls_tokens, None

        for attn, ff in self.encoder_tail:
            x = attn(x, mask = src_mask) + x
            x = ff(x) + x

        return x[:, 1:], cls_tokens
class Decoder(nn.Module):
    """Causal pre-norm transformer decoder.

    The first `head_depth` layers are self-attention + feed-forward only. Each
    of the remaining `depth - head_depth` layers also cross-attends to the
    evidence encodings, followed by a second feed-forward that uses `ff_mult`.
    """

    def __init__(self, dim, depth, head_depth = 4, heads = 8, ff_mult = 4, attn_dropout = 0., ff_dropout = 0.):
        super().__init__()
        self.decoder_head = nn.ModuleList([])
        self.decoder_tail = nn.ModuleList([])

        # `heads` and `ff_dropout` are forwarded so the constructor arguments take effect
        for _ in range(head_depth):
            self.decoder_head.append(nn.ModuleList([
                PreNorm(dim, SelfAttention(dim, heads = heads, causal = True, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
            ]))

        for _ in range(depth - head_depth):
            self.decoder_tail.append(nn.ModuleList([
                PreNorm(dim, SelfAttention(dim, heads = heads, causal = True, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim, dropout = ff_dropout)),
                PreNorm(dim, CrossAttention(dim, heads = heads, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
            ]))

    def forward(self, x, *, context, similarities, src_mask = None, context_mask = None):
        """Return (hidden states, list of pre-bias cross-attention scores, one per tail layer)."""
        for self_attn, self_ff in self.decoder_head:
            x = self_attn(x, mask = src_mask) + x
            x = self_ff(x) + x

        cross_pre_attns = []

        for self_attn, self_ff, cross_attn, cross_ff in self.decoder_tail:
            x = self_attn(x, mask = src_mask) + x
            x = self_ff(x) + x

            x_out, attn = cross_attn(x, context, similarities, mask = src_mask, context_mask = context_mask)
            x = x_out + x

            x = cross_ff(x) + x

            cross_pre_attns.append(attn)

        return x, cross_pre_attns
class TransformerWrapper(nn.Module):
    """Token + learned absolute position embeddings around a layer stack.

    `layers` must return an iterable whose first element is the hidden states;
    any further elements are passed through unchanged. The hidden states are
    projected to `num_tokens` logits only when `return_logits` is True.
    """

    def __init__(self, num_tokens, dim, max_seq_len, layers, return_logits = False):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)
        self.max_seq_len = max_seq_len
        self.layers = layers
        if return_logits:
            self.to_logits = nn.Linear(dim, num_tokens)
        else:
            self.to_logits = identity

    def forward(self, x, *args, **kwargs):
        _, seq_len = x.shape
        assert seq_len <= self.max_seq_len, f'your sequence length {seq_len} needs to be less than or equal to the max sequence length {self.max_seq_len}'

        positions = torch.arange(seq_len, device=x.device)
        hidden = self.token_emb(x)
        hidden += self.pos_emb(positions)

        hidden, *extra = self.layers(hidden, *args, **kwargs)
        return (self.to_logits(hidden), *extra)
class Marge(nn.Module):
    """MARGE model: an encoder over evidence documents and an autoregressive
    decoder that reconstructs a target document while cross-attending to the
    evidence, with attention biased by target/evidence embedding similarity."""

    def __init__(
        self,
        dim,
        num_tokens = 20000,
        max_seq_len = 1024,
        enc_depth = 12,
        enc_retrieval_depth = 4,
        enc_heads = 8,
        enc_ff_mult = 4,
        enc_attn_dropout = 0.,
        enc_ff_dropout = 0.,
        dec_depth = 12,
        dec_heads = 8,
        dec_ff_mult = 16,
        dec_attn_dropout = 0.,
        dec_ff_dropout = 0.,
        distill_attn = False,
        distill_loss_coef = 1.
    ):
        super().__init__()
        self.dim = dim

        self.encoder = TransformerWrapper(num_tokens, dim, max_seq_len, Encoder(dim, depth = enc_depth, retrieval_depth = enc_retrieval_depth, heads = enc_heads, ff_mult = enc_ff_mult, attn_dropout = enc_attn_dropout, ff_dropout = enc_ff_dropout))
        self.decoder = TransformerWrapper(num_tokens, dim, max_seq_len, Decoder(dim, depth = dec_depth, heads = dec_heads, ff_mult = dec_ff_mult, attn_dropout = dec_attn_dropout, ff_dropout = dec_ff_dropout), return_logits = True)

        # encoder and decoder share a single token embedding table
        self.encoder.token_emb = self.decoder.token_emb

        self.decoder = AutoregressiveWrapper(self.decoder)

        # experimental auxiliary attention-distillation loss
        self.distill_attn = distill_attn
        self.distill_loss_coef = distill_loss_coef

    def get_embeds(self, documents, batch_size = 16, masks = None):
        """Return L2-normalised CLS embeddings (num_docs, dim) of `documents`.

        documents: (num_docs, seq) token ids, encoded `batch_size` at a time.
        masks: optional bool mask with the same shape as `documents`.
        """
        embeds = []

        batched_documents = documents.split(batch_size)
        batched_masks = masks.split(batch_size) if exists(masks) else ([None] * len(batched_documents))

        for docs, mask in zip(batched_documents, batched_masks):
            embed, *_ = self.encoder(docs, src_mask = mask, return_embed_only = True)
            embeds.append(embed)

        embeds = torch.cat(embeds)
        return F.normalize(embeds, dim=-1)

    @torch.no_grad()
    def generate(self, prime, seq_len, evidence, mask = None, similarities = None):
        """Sample `seq_len` tokens after `prime`, conditioned on `evidence`.

        evidence: (b, m, n) token ids of m evidence documents per sample.
        mask: optional bool (b, m, n) evidence mask.
        similarities: optional (b, m) document similarities; defaults to all ones.
        """
        b, num_evidences, *_ = evidence.shape
        evidence = rearrange(evidence, 'b m n -> (b m) n')
        enc_src_mask = rearrange(mask, 'b m n -> (b m) n') if exists(mask) else None

        encodings, evidence_embeds = self.encoder(evidence, src_mask = enc_src_mask)
        encodings = rearrange(encodings, '(b m) n d -> b m n d', m = num_evidences)

        # create the default on the evidence's device instead of assuming CUDA
        similarities = similarities if exists(similarities) else torch.ones((b, num_evidences), device = evidence.device).float()

        # the encoder drops the CLS position from its outputs, so the evidence mask
        # already matches the encodings (padding it would break the shapes in cross-attention)
        context_mask = mask
        return self.decoder.generate(prime, seq_len, context = encodings, similarities = similarities, context_mask = context_mask)

    def forward(self, evidence, target, target_embeds, src_mask = None, tgt_mask = None):
        """Return the reconstruction loss (plus the weighted distillation loss if enabled).

        evidence: (b, m, n) token ids; target: (b, t) token ids;
        target_embeds: (b, dim) target document embeddings;
        src_mask: optional bool (b, m, n); tgt_mask: optional bool (b, t).
        """
        num_evidences = evidence.shape[1]
        evidence = rearrange(evidence, 'b m n -> (b m) n')
        enc_src_mask = rearrange(src_mask, 'b m n -> (b m) n') if exists(src_mask) else None
        encodings, evidence_embeds = self.encoder(evidence, src_mask = enc_src_mask)
        encodings = rearrange(encodings, '(b m) n d -> b m n d', m = num_evidences)
        evidence_embeds = rearrange(evidence_embeds, '(b m) d -> b m d', m = num_evidences)

        # relevance of each evidence document to its target: dot product of embeddings
        similarities = einsum('bmd,bd->bm', evidence_embeds, target_embeds)

        # the decoder consumes target[:, :-1], so drop the last mask column to match
        dec_src_mask = tgt_mask[:, :-1] if exists(tgt_mask) else None
        loss, cross_attns = self.decoder(target, context = encodings, similarities = similarities, src_mask = dec_src_mask, context_mask = src_mask)

        if self.distill_attn:
            cross_attns = torch.stack(cross_attns, dim = 1)
            cross_attns = rearrange(cross_attns, 'b l h i (n j) -> b l h i n j', n = num_evidences)
            distill_loss = distill_attn_loss(cross_attns, similarities, mask = src_mask)
            aux_loss = self.distill_loss_coef * distill_loss
            loss = loss + aux_loss

        return loss
# training related classes

def remove_target_from_evidence(evidence_ids, target_ids):
    """Drop each target's own id from its retrieved evidence row.

    evidence_ids: (b, n) numpy array of retrieved document ids per target.
    target_ids: (b,) numpy array holding each target's id.

    Returns a (b, n - 1) array. Rows that contain their target id lose that
    entry. Rows that do not contain it lose their last column instead
    (presumably the farthest neighbour — depends on the index's result order).
    NOTE(review): a row containing its target id more than once cannot be
    reshaped to n - 1 columns and will raise.
    """
    num_rows, num_cols = evidence_ids.shape
    is_target = evidence_ids == target_ids[:, None]
    lacks_target = ~is_target.any(axis=-1, keepdims=True)
    drop_last = np.concatenate((np.zeros((num_rows, num_cols - 1), dtype=bool), lacks_target), axis=1)
    keep = ~(is_target | drop_last)
    return evidence_ids[keep].reshape(num_rows, num_cols - 1)
# Memory-mapped dataset yielding (target, target mask, evidence, evidence mask),
# where the evidence rows for each target come from a k-nearest-neighbour memmap.
class DocumentDataset(Dataset):
    def __init__(self, num_docs, doc_seq_len, num_evidences, documents_path, masks_path, num_targets, target_seq_len, target_path, target_masks_path):
        super().__init__()
        self.shape = (num_docs, doc_seq_len)
        self.target_shape = (num_targets, target_seq_len)
        self.knn_shape = (num_targets, num_evidences)
        self.documents = np.memmap(documents_path, dtype=np.int32, shape=self.shape)
        self.targets = np.memmap(target_path, dtype=np.int32, shape=self.target_shape)
        # `np.bool` is unavailable in NumPy 1.24-1.26; `np.bool_` is the same boolean dtype in every version
        self.masks = np.memmap(masks_path, dtype=np.bool_, shape=self.shape) if exists(masks_path) else None
        self.target_masks = np.memmap(target_masks_path, dtype=np.bool_, shape=self.target_shape) if exists(target_masks_path) else None
        # set later via set_knn_path (must happen before indexing)
        self.knn = None

    def set_knn_path(self, path):
        """(Re)open the (num_targets, num_evidences) int32 nearest-neighbour memmap at `path`."""
        if exists(self.knn):
            del self.knn

        self.knn = np.memmap(path, dtype=np.int32, shape=self.knn_shape)

    def __len__(self):
        return self.target_shape[0]

    def __getitem__(self, ind):
        """Return CUDA tensors (target, target_mask, evidence, evidence_mask) for target `ind`.

        Missing masks default to all-True.
        """
        assert exists(self.knn), 'The memmap path to the generated k nearest neighbors for evidences must be set for the dataset'

        target_data = torch.from_numpy(self.targets[ind, :]).long()
        target_masks = torch.from_numpy(self.target_masks[ind, :]) if exists(self.target_masks) else torch.ones_like(target_data).bool()

        evidence_ids = self.knn[ind, :]
        evidence_data = torch.from_numpy(self.documents[evidence_ids, :]).long()
        evidence_masks = torch.from_numpy(self.masks[evidence_ids, :]) if exists(self.masks) else torch.ones_like(evidence_data).bool()
        return target_data.cuda(), target_masks.cuda(), evidence_data.cuda(), evidence_masks.cuda()
# Approximate nearest-neighbour index: IVF-PQ with an HNSW coarse quantizer,
# copied onto all available GPUs.
class FaissANN:
    def __init__(
        self,
        dim,
        num_documents,
        num_subvectors = 16,
        hnsw_m = 32,
        nbits = 8
    ):
        super().__init__()
        # number of inverted lists: floor(sqrt(corpus size))
        num_lists = math.floor(math.sqrt(num_documents))
        coarse_quantizer = faiss.IndexHNSWFlat(dim, hnsw_m)
        cpu_index = faiss.IndexIVFPQ(coarse_quantizer, dim, num_lists, num_subvectors, nbits)
        self.index = faiss.index_cpu_to_all_gpus(cpu_index)
        # number of vectors the caller should sample to train the index
        self.num_training = max(num_lists * 10, 256)

    def reset(self):
        """Remove all vectors from the index."""
        return self.index.reset()

    def train(self, x):
        """Train the index on the float vectors `x`."""
        return self.index.train(x)

    def add(self, x):
        """Add the float vectors `x` to the index."""
        return self.index.add(x)

    def search(self, x, topk, nprobe=8):
        """Return (distances, ids) of the `topk` neighbours of each row of `x`,
        probing `nprobe` inverted lists."""
        self.index.nprobe = nprobe
        return self.index.search(x, k=topk)
# Training wrapper: holds the model, the nearest-neighbour index used to pick
# evidence documents for every target, and the memmap-backed dataset.
class TrainingWrapper(nn.Module):
    def __init__(
        self,
        model,
        *,
        num_documents,
        doc_seq_len,
        documents_memmap_path,
        masks_memmap_path = None,
        num_targets = None,
        target_seq_len = None,
        target_memmap_path = None,
        target_masks_memmap_path = None,
        num_evidence = 4,
        reindex_batch_size = 4,
        use_faiss_ann = False
    ):
        """
        model: the model to train; must expose `.dim` and `.get_embeds`.
        num_documents, doc_seq_len: shape of the int32 evidence memmap at
            `documents_memmap_path`; `masks_memmap_path` optionally holds a
            boolean mask of the same shape.
        num_targets, target_seq_len, target_memmap_path, target_masks_memmap_path:
            describe a separate target set. When `target_memmap_path` is None,
            the evidence documents double as targets.
        num_evidence: evidence documents retrieved per target.
        reindex_batch_size: batch size used to embed documents while reindexing.
        use_faiss_ann: use an approximate IVF-PQ index instead of exact L2 search.
        """
        super().__init__()
        self.dim = model.dim
        self.num_evidence = num_evidence

        self.model = model.cuda()
        self.num_docs = num_documents

        # targets default to the evidence documents themselves
        num_targets = default(num_targets, num_documents)
        self.num_targets = num_targets

        self.doc_shape = (num_documents, doc_seq_len)

        self.documents_path = documents_memmap_path
        self.separate_target_and_evidence = exists(target_memmap_path)

        if self.separate_target_and_evidence:
            # NOTE(review): num_targets was already defaulted above, so this check cannot fail
            assert exists(num_targets), 'number of target documents must be defined if target document set is different than evidence document set'
            assert exists(target_seq_len), 'target sequence length must be specified'
        else:
            target_memmap_path = default(target_memmap_path, documents_memmap_path)
            target_masks_memmap_path = default(target_masks_memmap_path, masks_memmap_path)
            target_seq_len = default(target_seq_len, doc_seq_len)

        self.target_shape = (num_targets, target_seq_len)
        self.target_path = target_memmap_path
        self.knn_path = f'{self.documents_path}.knn'

        self.use_faiss_ann = use_faiss_ann
        if use_faiss_ann:
            self.index = FaissANN(self.dim, self.num_docs)
        else:
            index = faiss.IndexFlatL2(self.dim)
            self.index = faiss.index_cpu_to_all_gpus(index)

        self.reindex_batch_size = reindex_batch_size
        self.reindex()

        self.dataset = DocumentDataset(
            num_documents,
            doc_seq_len,
            num_evidence,
            documents_memmap_path,
            masks_memmap_path,
            num_targets,
            target_seq_len,
            target_memmap_path,
            target_masks_memmap_path
        )

        self.dataset.set_knn_path(self.knn_path)

    def get_dataset(self):
        return self.dataset

    @torch.no_grad()
    def reindex(self):
        """Recompute the evidence neighbours of every target with the current model.

        Embeds all documents, adds them to the index (training it first when
        using the approximate index), queries it with every target embedding,
        writes the neighbour ids to the knn memmap at `self.knn_path`, and
        finally empties the index.
        """
        batch_size = self.reindex_batch_size

        def get_embeds(data):
            embeds = self.model.get_embeds(data, batch_size = batch_size)
            return embeds.detach().cpu().numpy()

        # the knn table has one row per target (the dataset reads it with shape (num_targets, num_evidence))
        with memmap(self.documents_path, dtype=np.int32, shape=self.doc_shape) as doc_pointer, \
             memmap(self.target_path, dtype=np.int32, shape=self.target_shape) as target_pointer, \
             memmap(self.knn_path, dtype=np.int32, shape=(self.num_targets, self.num_evidence), mode='w+') as knn_writer:

            if self.use_faiss_ann:
                # train the approximate index on a random sample of documents
                random_indices = np.random.permutation(self.num_docs)[:self.index.num_training]
                np_data = torch.from_numpy(doc_pointer[random_indices]).cuda().long()
                train_embeds = get_embeds(np_data)
                self.index.train(train_embeds)

            total_evidence_chunks = math.ceil(self.num_docs / batch_size)

            for data_slice in tqdm(chunk(batch_size, self.num_docs), total=total_evidence_chunks, desc='Adding embedding to indexes'):
                np_data = torch.from_numpy(doc_pointer[data_slice, :]).cuda().long()
                embeds = get_embeds(np_data)
                self.index.add(embeds)

            total_target_chunks = math.ceil(self.num_targets / batch_size)

            for data_slice in tqdm(chunk(batch_size, self.num_targets), total=total_target_chunks, desc='Fetching and storing nearest neighbors'):
                np_data = torch.from_numpy(target_pointer[data_slice, :]).cuda().long()
                embeds = get_embeds(np_data)

                # when targets are the documents themselves, fetch one extra neighbour
                # so the target's own id can be removed
                fetch_num_evidences = self.num_evidence + (0 if self.separate_target_and_evidence else 1)
                _, evidence_ids = self.index.search(embeds, fetch_num_evidences)

                target_ids = np.arange(data_slice.start, data_slice.stop)

                if not self.separate_target_and_evidence:
                    evidence_ids = remove_target_from_evidence(evidence_ids, target_ids)

                knn_writer[data_slice, :] = evidence_ids

        self.index.reset()

        print('reindexing complete')

    def forward(self, data):
        """Compute the model loss for one batch (targets, target_masks, evidences, evidence_masks)."""
        targets, target_masks, evidences, evidence_masks = data
        target_embeds = self.model.get_embeds(targets, masks=target_masks)
        loss = self.model(evidences, targets, target_embeds, src_mask=evidence_masks, tgt_mask=target_masks)
        return loss
.\lucidrains\marge-pytorch\marge_pytorch\__init__.py
# 从 marge_pytorch 模块中导入 Marge 和 TrainingWrapper 类
# 从 marge_pytorch 模块中导入 AutoregressiveWrapper 类
from marge_pytorch.marge_pytorch import Marge, TrainingWrapper
from marge_pytorch.autoregressive_wrapper import AutoregressiveWrapper

Marge - Pre-training via Paraphrasing
Implementation of Marge, Pre-training via Paraphrasing, in Pytorch. It is an alternative to masked language modeling pretraining, where an encoder / decoder attention network learns to reconstruct a target document from a collection of evidence documents.
Update: Three researchers have independently reported that the repository works for them
Install
$ pip install marge-pytorch
Usage
import torch
import numpy as np
from torch.utils.data import DataLoader
from marge_pytorch import Marge, TrainingWrapper
# your documents must be tokenized and stored as memmap in the shape (num documents, seq length)
# constants
NUM_DOCS = 10000
SEQ_LEN = 1024
SHAPE = (NUM_DOCS, SEQ_LEN)
# generate mock training data
f = np.memmap('./train.dat', dtype=np.int32, mode='w+', shape=SHAPE)
f[:] = np.random.randint(0, 20000, size=SHAPE)
del f
# generate mock masking data
f = np.memmap('./train.mask.dat', dtype=np.bool_, mode='w+', shape=SHAPE)
f[:] = np.full(SHAPE, True)
del f
# instantiate model
model = Marge(
dim = 512,
num_tokens = 20000,
max_seq_len = SEQ_LEN,
enc_depth = 12,
enc_retrieval_depth = 4, # defaults to 4 as in paper (take the CLS token after the 4th layer of the encoder)
enc_heads = 8,
enc_ff_mult = 4,
dec_depth = 12,
dec_heads = 8,
dec_ff_mult = 16, # paper noted that decoder needs to have much bigger feed forward sizes
distill_attn = False, # (experimental) will add, on top of the decoder loss, an auxiliary distillation loss as defined in https://arxiv.org/abs/2012.04584
distill_loss_coef = 1. # weight of distillation auxiliary loss
)
# wrap your model and your documents
trainer = TrainingWrapper(
model,
num_documents = NUM_DOCS,
doc_seq_len = SEQ_LEN,
num_evidence = 4, # number of evidence documents to fetch per target document to construct
reindex_batch_size = 32, # batch size to use when reindexing
documents_memmap_path = './train.dat', # path to the mem-mapped documents
masks_memmap_path = './train.mask.dat', # if None is supplied, will assume all tokens are visible
use_faiss_ann = True # set this to false if you have a low number of documents, and approximate nearest neighbor is not needed
)
# instantiate dataloader
dl = DataLoader(trainer.dataset, batch_size=16)
# now you can train, and use the reindex method on the training wrapper at appropriate intervals
for ind, data in enumerate(dl):
loss = trainer(data)
loss.backward()
# optimizer step and all that
# reindex and precompute knn every 10000 steps, as in paper
if ind > 0 and ind % 10000 == 0:
trainer.reindex()
Save your model after much training
torch.save(model, f'./trained-model.pt')
Advanced
If you would like the target and evidence documents to be from different sets, you just have to pass in up to four additional keyword arguments, as shown below.
trainer = TrainingWrapper(
model,
num_documents = NUM_DOCS,
doc_seq_len = SEQ_LEN,
num_evidence = 4,
reindex_batch_size = 32,
documents_memmap_path = './evidence.dat',
masks_memmap_path = './evidence.mask.dat',
num_targets = NUM_TARGETS, # 1. number of target documents, with sequence length the same as the document (evidence)
target_seq_len = SEQ_LEN, # 2. sequence length of target documents
target_memmap_path = './target.dat', # 3. path to target memmap, same as documents (evidence)
target_masks_memmap_path = './target.mask.dat', # 4. path to target mask memmap, same as document masks (evidence)
use_faiss_ann = True
)
Sampling
You can sample from the decoder with the following instructions
# some random evidence from the dataset
# or provide your own in the dimensions (b x num_evidences x seq_len)
*_, evidence, mask = trainer.dataset[0:1]
# assume 1 is start token
prime = torch.tensor([[1.]]).long().cuda()
# supply your own document similarities array (b x num_evidences)
# if not supplied, will default to 1. for all evidence
doc_similarities = torch.ones(evidence.shape[:2]).float().cuda()
# generate sample of length 1024
samples = model.generate(prime, 1024, evidence, mask = mask, similarities = doc_similarities)
Citations
@misc{lewis2020pretraining,
title={Pre-training via Paraphrasing},
author={Mike Lewis and Marjan Ghazvininejad and Gargi Ghosh and Armen Aghajanyan and Sida Wang and Luke Zettlemoyer},
year={2020},
eprint={2006.15020},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@misc{komatsuzaki2020current,
title={Current Limitations of Language Models: What You Need is Retrieval},
author={Aran Komatsuzaki},
year={2020},
eprint={2009.06857},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@misc{izacard2020distilling,
title={Distilling Knowledge from Reader to Retriever for Question Answering},
author={Gautier Izacard and Edouard Grave},
year={2020},
eprint={2012.04584},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
.\lucidrains\marge-pytorch\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# package metadata

KEYWORDS = [
    'artificial intelligence',
    'attention mechanism',
    'transformers',
    'pre-training',
]

INSTALL_REQUIRES = [
    'einops>=0.3',
    'faiss-gpu',
    'numpy',
    'torch>=1.6',
    'tqdm',
]

CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name = 'marge-pytorch',
    packages = find_packages(),
    version = '0.2.9',
    license = 'MIT',
    description = 'Marge - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/marge-pytorch',
    keywords = KEYWORDS,
    install_requires = INSTALL_REQUIRES,
    classifiers = CLASSIFIERS,
)
.\lucidrains\med-seg-diff-pytorch\driver.py
import os
import argparse
from tqdm import tqdm
import torch
import numpy as np
import torchvision.transforms as transforms
from torch.optim import AdamW
from lion_pytorch import Lion
from med_seg_diff_pytorch import Unet, MedSegDiff
from med_seg_diff_pytorch.dataset import ISICDataset, GenericNpyDataset
from accelerate import Accelerator
import wandb
## Parse CLI arguments ##
def parse_args():
    """Build the training command-line parser and return the parsed namespace."""

    def str2bool(value):
        # argparse's `type=bool` turns every non-empty string (including "False") into True,
        # so textual booleans are parsed explicitly
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()

    parser.add_argument('-slr', '--scale_lr', action='store_true', help="Whether to scale lr.")
    parser.add_argument('-rt', '--report_to', type=str, default="wandb", choices=["wandb"],
                        help="Where to log to. Currently only supports wandb")
    parser.add_argument('-ld', '--logging_dir', type=str, default="logs", help="Logging dir.")
    parser.add_argument('-od', '--output_dir', type=str, default="output", help="Output dir.")
    parser.add_argument('-mp', '--mixed_precision', type=str, default="no", choices=["no", "fp16", "bf16"],
                        help="Whether to do mixed precision")
    parser.add_argument('-ga', '--gradient_accumulation_steps', type=int, default=4,
                        help="The number of gradient accumulation steps.")
    parser.add_argument('-img', '--img_folder', type=str, default='ISBI2016_ISIC_Part3B_Training_Data',
                        help='The image file path from data_path')
    parser.add_argument('-csv', '--csv_file', type=str, default='ISBI2016_ISIC_Part3B_Training_GroundTruth.csv',
                        help='The csv file to load in from data_path')
    parser.add_argument('-sc', '--self_condition', action='store_true', help='Whether to do self condition')
    parser.add_argument('-lr', '--learning_rate', type=float, default=5e-4, help='learning rate')
    parser.add_argument('-ab1', '--adam_beta1', type=float, default=0.95,
                        help='The beta1 parameter for the Adam optimizer.')
    parser.add_argument('-ab2', '--adam_beta2', type=float, default=0.999,
                        help='The beta2 parameter for the Adam optimizer.')
    parser.add_argument('-aw', '--adam_weight_decay', type=float, default=1e-6,
                        help='Weight decay magnitude for the Adam optimizer.')
    parser.add_argument('-ae', '--adam_epsilon', type=float, default=1e-08,
                        help='Epsilon value for the Adam optimizer.')
    parser.add_argument('-ul', '--use_lion', type=str2bool, default=False, help='use Lion optimizer (true/false)')
    parser.add_argument('-ic', '--mask_channels', type=int, default=1, help='number of mask channels (default: 1)')
    parser.add_argument('-c', '--input_img_channels', type=int, default=3,
                        help='number of input image channels (default: 3)')
    parser.add_argument('-is', '--image_size', type=int, default=128, help='input image size (default: 128)')
    parser.add_argument('-dd', '--data_path', default='./data', help='directory of input image')
    parser.add_argument('-d', '--dim', type=int, default=64, help='dim (default: 64)')
    parser.add_argument('-e', '--epochs', type=int, default=10000, help='number of epochs (default: 10000)')
    parser.add_argument('-bs', '--batch_size', type=int, default=8, help='batch size to train on (default: 8)')
    parser.add_argument('--timesteps', type=int, default=1000, help='number of timesteps (default: 1000)')
    parser.add_argument('-ds', '--dataset', default='generic', help='Dataset to use')
    parser.add_argument('--save_every', type=int, default=100, help='save_every n epochs (default: 100)')
    parser.add_argument('--load_model_from', default=None, help='path to pt file to load from')

    return parser.parse_args()
def load_data(args):
    """Return a shuffled training DataLoader for `args.dataset` ('ISIC' or 'generic')."""
    kind = args.dataset

    if kind == 'ISIC':
        # resize every image/mask pair to a fixed square, then convert to tensors
        transform_train = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size)),
            transforms.ToTensor(),
        ])
        dataset = ISICDataset(
            args.data_path, args.csv_file, args.img_folder,
            transform=transform_train, training=True, flip_p=0.5,
        )
    elif kind == 'generic':
        # ToPILImage first: generic samples are presumably arrays loaded from .npy (confirm).
        # NOTE(review): Resize with an int only fixes the shorter edge.
        transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(args.image_size),
            transforms.ToTensor(),
        ])
        dataset = GenericNpyDataset(args.data_path, transform=transform_train, test_flag=False)
    else:
        raise NotImplementedError(f"Your dataset {args.dataset} hasn't been implemented yet.")

    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    return loader
def main():
    """Train MedSegDiff from the CLI configuration.

    Every `save_every` epochs this saves a checkpoint (main process only) and
    logs one predicted mask alongside its input image and ground-truth mask.
    """
    args = parse_args()
    checkpoint_dir = os.path.join(args.output_dir, 'checkpoints')
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    os.makedirs(checkpoint_dir, exist_ok=True)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
    )
    if accelerator.is_main_process:
        accelerator.init_trackers("med-seg-diff", config=vars(args))

    ## DEFINE MODEL ##
    model = Unet(
        dim=args.dim,
        image_size=args.image_size,
        dim_mults=(1, 2, 4, 8),
        mask_channels=args.mask_channels,
        input_img_channels=args.input_img_channels,
        self_condition=args.self_condition
    )

    ## LOAD DATA ##
    data_loader = load_data(args)

    if args.scale_lr:
        # scale by the effective global batch size (accumulation x batch x processes)
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.batch_size * accelerator.num_processes
        )

    ## Initialize optimizer
    if not args.use_lion:
        optimizer = AdamW(
            model.parameters(),
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )
    else:
        optimizer = Lion(
            model.parameters(),
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay
        )

    ## TRAIN MODEL ##
    counter = 0
    model, optimizer, data_loader = accelerator.prepare(
        model, optimizer, data_loader
    )
    diffusion = MedSegDiff(
        model,
        timesteps=args.timesteps
    ).to(accelerator.device)

    if args.load_model_from is not None:
        save_dict = torch.load(args.load_model_from)
        diffusion.model.load_state_dict(save_dict['model_state_dict'])
        optimizer.load_state_dict(save_dict['optimizer_state_dict'])
        accelerator.print(f'Loaded from {args.load_model_from}')

    ## Iterate across training loop
    for epoch in range(args.epochs):
        running_loss = 0.0
        num_samples = 0
        accelerator.print('Epoch {}/{}'.format(epoch + 1, args.epochs))
        for (img, mask) in tqdm(data_loader):
            with accelerator.accumulate(model):
                loss = diffusion(mask, img)
                loss_value = loss.item()
                # weight by batch size and count samples so the epoch figure is a per-sample mean
                running_loss += loss_value * img.size(0)
                num_samples += img.size(0)
                accelerator.log({'loss': loss_value})  # Log loss to wandb
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
            counter += 1
        epoch_loss = running_loss / max(num_samples, 1)
        accelerator.print('Training Loss : {:.4f}'.format(epoch_loss))

        ## INFERENCE ##
        if epoch % args.save_every == 0:
            # only one process writes the checkpoint, avoiding concurrent writes to the same file
            if accelerator.is_main_process:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': diffusion.model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                }, os.path.join(checkpoint_dir, f'state_dict_epoch_{epoch}_loss_{epoch_loss}.pt'))

            pred = diffusion.sample(img).cpu().detach().numpy()

            for tracker in accelerator.trackers:
                if tracker.name == "wandb":
                    # log one image per batch
                    tracker.log(
                        {'pred-img-mask': [wandb.Image(pred[0, 0, :, :]), wandb.Image(img[0, 0, :, :]),
                                           wandb.Image(mask[0, 0, :, :])]}
                    )

    # finalize the trackers started by init_trackers
    accelerator.end_training()
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
.\lucidrains\med-seg-diff-pytorch\med_seg_diff_pytorch\dataset.py
import os
import numpy as np
# 设置环境变量,允许重复加载库
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import torch
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd
import random
import torchvision.transforms.functional as F
# Dataset for the ISIC skin-lesion data: each CSV row names an image, and the
# matching mask is "<name>_Segmentation.png" in the same folder.

class ISICDataset(Dataset):
    """ISIC image/mask dataset driven by a CSV file.

    The first CSV column holds the image base name (the file is ``<name>.jpg`` inside
    ``img_folder``); the second column holds the label. In training mode the label is
    text ('benign' -> 0, anything else -> 1) and only ``(img, mask)`` is returned;
    otherwise the label is parsed with ``int`` and ``(img, mask, label)`` is returned.

    When a ``transform`` is given it is applied to both image and mask with the same
    torch RNG state, followed by a random vertical flip with probability ``flip_p``
    (in both training and non-training mode).
    """

    def __init__(self, data_path, csv_file, img_folder, transform=None, training=True, flip_p=0.5):
        df = pd.read_csv(os.path.join(data_path, csv_file), encoding='gbk')
        self.img_folder = img_folder
        self.name_list = df.iloc[:, 0].tolist()
        self.label_list = df.iloc[:, 1].tolist()
        self.data_path = data_path
        self.transform = transform
        self.training = training
        self.flip_p = flip_p

    def __len__(self):
        return len(self.name_list)

    def _paths(self, index):
        """Return ``(image_path, mask_path)`` for sample ``index``."""
        # Build both file names from the raw base name. Splitting "<name>.jpg" on '.'
        # would truncate base names that themselves contain dots. str() also lets
        # numeric ids read by pandas work.
        base = str(self.name_list[index])
        folder = os.path.join(self.data_path, self.img_folder)
        return (os.path.join(folder, base + '.jpg'),
                os.path.join(folder, base + '_Segmentation.png'))

    def __getitem__(self, index):
        """Load the image/mask pair (plus label outside training) at ``index``."""
        img_path, msk_path = self._paths(index)

        # Context managers close the underlying files once the converted copies exist.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        with Image.open(msk_path) as raw_mask:
            mask = raw_mask.convert('L')

        if self.training:
            label = 0 if self.label_list[index] == 'benign' else 1
        else:
            label = int(self.label_list[index])

        if self.transform:
            # Save the RNG state so that random transforms apply identically to img and mask.
            state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(state)
            mask = self.transform(mask)

            if random.random() < self.flip_p:
                img = F.vflip(img)
                mask = F.vflip(mask)

        if self.training:
            return (img, mask)
        return (img, mask, label)
# Generic dataset over a directory of .npy arrays

class GenericNpyDataset(torch.utils.data.Dataset):
    def __init__(self, directory: str, transform, test_flag: bool = True):
        '''
        Generic dataset for loading npy files.

        Each .npy file stores a 3D array (H, W, C): the first two dimensions are the
        image plane and the third is the channel axis. Channel 0 is the image; the
        remaining channels are the label, binarised to {0, 1} (any value > 0 -> 1).

        Files are listed in sorted order so indexing is deterministic across runs and
        filesystems. With ``test_flag`` each item is ``(image, mask, filename)``,
        otherwise ``(image, mask)``.
        '''
        super().__init__()
        self.directory = os.path.expanduser(directory)
        self.transform = transform
        self.test_flag = test_flag
        # os.listdir order is arbitrary; sort for reproducible indexing
        self.filenames = sorted(x for x in os.listdir(self.directory) if x.endswith('.npy'))

    def __getitem__(self, x: int):
        fname = self.filenames[x]
        npy_img = np.load(os.path.join(self.directory, fname))

        # (H, W, C) -> (C, H, W); channel 0 is the image, channels 1: are the label
        image = torch.from_numpy(npy_img[:, :, :1]).permute(2, 0, 1)
        mask = np.where(npy_img[:, :, 1:] > 0, 1, 0)
        mask = torch.from_numpy(mask).permute(2, 0, 1).float()

        if self.transform:
            # Save the RNG state so that random transforms apply identically to image and mask.
            state = torch.get_rng_state()
            image = self.transform(image)
            torch.set_rng_state(state)
            mask = self.transform(mask)

        if self.test_flag:
            return image, mask, fname
        return image, mask

    def __len__(self) -> int:
        return len(self.filenames)
.\lucidrains\med-seg-diff-pytorch\med_seg_diff_pytorch\med_seg_diff_pytorch.py
# 导入所需的库
import math
import copy
from random import random
from functools import partial
from collections import namedtuple
# 导入第三方库
from beartype import beartype
# 导入 PyTorch 库
import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.fft import fft2, ifft2
# 导入 einops 库
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
# 导入 tqdm 库
from tqdm.auto import tqdm
# Constants
# Named pair produced by MedSegDiff.model_predictions: the predicted noise and
# the predicted clean sample (x_0) for the same timestep.
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
# Helper functions

def exists(x):
    """Return True when ``x`` is anything other than ``None``."""
    return not (x is None)
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    A callable fallback is invoked lazily to produce the value.
    """
    if val is None:
        return d() if callable(d) else d
    return val
def identity(t, *args, **kwargs):
    """Return the first argument unchanged; extra arguments are ignored."""
    result = t
    return result
# Normalization helpers

def normalize_to_neg_one_to_one(img):
    """Map values from [0, 1] to [-1, 1] via the affine map 2 * img - 1."""
    doubled = 2 * img
    return doubled - 1
def unnormalize_to_zero_to_one(t):
    """Inverse of the [0, 1] -> [-1, 1] normalization: map [-1, 1] back to [0, 1]."""
    shifted = t + 1
    return shifted * 0.5
# Small helper modules

class Residual(Module):
    """Wrap a module ``fn`` so that its output is added back onto its input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x
# Upsampling: nearest-neighbour 2x followed by a 3x3 conv

def Upsample(dim, dim_out = None):
    """Double the spatial size (nearest neighbour), then project ``dim`` -> ``dim_out``
    channels with a 3x3 conv. ``dim_out`` defaults to ``dim``."""
    out_channels = dim if dim_out is None else dim_out
    layers = [
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, out_channels, 3, padding = 1),
    ]
    return nn.Sequential(*layers)
# Downsampling: fold each 2x2 neighbourhood into channels, then a 1x1 conv

def Downsample(dim, dim_out = None):
    """Halve the spatial size by folding every 2x2 patch into channels (giving 4 * dim),
    then project to ``dim_out`` channels (defaults to ``dim``) with a 1x1 conv."""
    out_channels = dim if dim_out is None else dim_out
    space_to_depth = Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2)
    return nn.Sequential(space_to_depth, nn.Conv2d(dim * 4, out_channels, 1))
# Channel-wise layer norm for (b, c, h, w) feature maps

class LayerNorm(Module):
    """Normalize over dim 1 (channels) with a learnable gain and an optional learnable bias."""

    def __init__(self, dim, bias = False):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) if bias else None

    def forward(self, x):
        # float32 uses eps 1e-5; any other dtype uses the larger 1e-3
        # (presumably for reduced-precision stability)
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        mean = torch.mean(x, dim = 1, keepdim = True)
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        shift = 0 if self.b is None else self.b
        normed = (x - mean) * (var + eps).rsqrt()
        return normed * self.g + shift
# Sinusoidal embedding of scalar timesteps

class SinusoidalPosEmb(Module):
    """Embed a 1-D tensor of positions into sin/cos features.

    The output has shape (len(x), 2 * (dim // 2)): sines first, then cosines, over
    geometrically spaced frequencies from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        log_step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device = x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim = -1)
# Building blocks

class Block(Module):
    """3x3 conv -> GroupNorm -> optional (scale, shift) modulation -> SiLU."""

    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        h = self.norm(self.proj(x))
        if scale_shift is not None:
            scale, shift = scale_shift
            h = h * (scale + 1) + shift
        return self.act(h)
# ResNet block: two Blocks, the first optionally modulated by the time embedding,
# plus a residual path (1x1 conv when the channel count changes)

class ResnetBlock(Module):
    """Residual block whose first sub-block receives a time-derived (scale, shift)."""

    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        if exists(time_emb_dim):
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = nn.Identity() if dim == dim_out else nn.Conv2d(dim, dim_out, 1)

    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            cond = rearrange(self.mlp(time_emb), 'b c -> b c 1 1')
            scale_shift = cond.chunk(2, dim = 1)

        h = self.block2(self.block1(x, scale_shift = scale_shift))
        return h + self.res_conv(x)
# Feed-forward network on feature maps built from 1x1 convolutions

def FeedForward(dim, mult = 4):
    """Pre-norm MLP: LayerNorm -> 1x1 conv (dim -> int(dim * mult)) -> GELU -> 1x1 conv back to dim."""
    inner_dim = int(dim * mult)
    layers = [
        LayerNorm(dim),                  # normalise channels first
        nn.Conv2d(dim, inner_dim, 1),    # expand
        nn.GELU(),
        nn.Conv2d(inner_dim, dim, 1),    # project back
    ]
    return nn.Sequential(*layers)
class LinearAttention(Module):
    """Linear-complexity attention over the spatial positions of a (b, c, h, w) feature map.

    Queries are softmax-normalised over the feature axis and keys over the position
    axis, so a (d x e) context matrix is formed before the queries are applied.
    """

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner = dim_head * heads

        self.prenorm = LayerNorm(dim)
        self.to_qkv = nn.Conv2d(dim, inner * 3, 1, bias = False)
        self.to_out = nn.Sequential(nn.Conv2d(inner, dim, 1), LayerNorm(dim))

    def forward(self, x):
        _, _, height, width = x.shape
        normed = self.prenorm(x)

        q, k, v = (
            rearrange(part, 'b (h c) x y -> b h c (x y)', h = self.heads)
            for part in self.to_qkv(normed).chunk(3, dim = 1)
        )

        q = q.softmax(dim = -2) * self.scale
        k = k.softmax(dim = -1)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = height, y = width)
        return self.to_out(out)
class Attention(Module):
    """Multi-head softmax self-attention over the spatial positions of a feature map."""

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner = dim_head * heads

        self.prenorm = LayerNorm(dim)
        self.to_qkv = nn.Conv2d(dim, inner * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape
        normed = self.prenorm(x)

        q, k, v = (
            rearrange(part, 'b (h c) x y -> b h c (x y)', h = self.heads)
            for part in self.to_qkv(normed).chunk(3, dim = 1)
        )

        logits = einsum('b h d i, b h d j -> b h i j', q * self.scale, k)
        weights = logits.softmax(dim = -1)

        out = einsum('b h i j, b h d j -> b h i d', weights, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = height, y = width)
        return self.to_out(out)
class Transformer(Module):
    """``depth`` layers, each a residual Attention followed by a residual FeedForward."""

    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 4,
        depth = 1
    ):
        super().__init__()
        self.layers = ModuleList([
            ModuleList([
                Residual(Attention(dim, dim_head = dim_head, heads = heads)),
                Residual(FeedForward(dim)),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attn_layer, ff_layer in self.layers:
            x = ff_layer(attn_layer(x))
        return x
# Vision transformer for the dynamic FF-parser

class ViT(Module):
    """Patch-based vision transformer mapping (b, channels, H, W) -> (b, channels_out, H, W).

    The input is cut into non-overlapping ``patch_size`` x ``patch_size`` patches. Each
    patch is projected to ``dim`` with a 1x1 conv and given a learned positional
    embedding. The tokens are processed by a ``Transformer`` built with the given
    ``depth``, ``heads`` and ``dim_head``, then projected back to pixel space. The output
    projection is zero-initialised, so a freshly constructed ViT outputs zeros.

    The positional embedding is sized for a square input of side ``image_size``, which
    must be divisible by ``patch_size``.
    """

    def __init__(
        self,
        dim,
        *,
        image_size,
        patch_size,
        channels = 3,
        channels_out = None,
        dim_head = 32,
        heads = 4,
        depth = 4,
    ):
        super().__init__()
        assert exists(image_size)
        assert (image_size % patch_size) == 0

        num_patches_height_width = image_size // patch_size
        self.pos_emb = nn.Parameter(torch.zeros(dim, num_patches_height_width, num_patches_height_width))

        channels_out = default(channels_out, channels)
        patch_dim = channels * (patch_size ** 2)
        output_patch_dim = channels_out * (patch_size ** 2)

        self.to_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = patch_size, p2 = patch_size),
            nn.Conv2d(patch_dim, dim, 1),
            LayerNorm(dim)
        )

        # heads must be forwarded, otherwise the argument is silently ignored
        self.transformer = Transformer(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            depth = depth
        )

        self.to_patches = nn.Sequential(
            LayerNorm(dim),
            nn.Conv2d(dim, output_patch_dim, 1),
            Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
        )

        # zero-init the output projection so the module starts as a no-op contribution
        nn.init.zeros_(self.to_patches[-2].weight)
        nn.init.zeros_(self.to_patches[-2].bias)

    def forward(self, x):
        """Tokenise ``x`` into patches, add positional embeddings, transform, and un-patch."""
        x = self.to_tokens(x)
        x = x + self.pos_emb
        x = self.transformer(x)
        return self.to_patches(x)
# FF-parser conditioning: modulate the mask features in Fourier space and fuse
# them with the conditioning-image features

class Conditioning(Module):
    """Fuse mask-branch features ``x`` into conditioning-branch features ``c``.

    ``x`` is taken to the frequency domain with ``fft2`` and multiplied by a learnable
    (dim, fmap_size, fmap_size) attention map. When ``dynamic`` is True, a ViT adds a
    per-sample adjustment computed from the spectra of ``x`` and ``c``. The result is
    brought back with ``ifft2`` (real part, cast to the input dtype). ``c`` is then
    updated as ``LayerNorm(x) * LayerNorm(c) * c`` and passed through a ``ResnetBlock``.

    Args:
        fmap_size: spatial side length of the feature maps (sizes the attention map).
        dim: channel count of ``x`` and ``c``.
        dynamic: enable the ViT-predicted adjustment of the attention map.
        image_size: side length given to the ViT (required when ``dynamic``).
        dim_head, heads, depth, patch_size: hyper-parameters forwarded to the ViT.
    """

    def __init__(
        self,
        fmap_size,
        dim,
        dynamic = True,
        image_size = None,
        dim_head = 32,
        heads = 4,
        depth = 4,
        patch_size = 16
    ):
        super().__init__()
        # learnable multiplier applied to the spectrum of x
        self.ff_parser_attn_map = nn.Parameter(torch.ones(dim, fmap_size, fmap_size))

        self.dynamic = dynamic

        if dynamic:
            self.to_dynamic_ff_parser_attn_map = ViT(
                dim = dim,
                channels = dim * 2 * 2, # input and condition, each split into real and imaginary parts
                channels_out = dim,
                image_size = image_size,
                patch_size = patch_size,
                heads = heads,
                dim_head = dim_head,
                depth = depth  # forward depth so the configured value takes effect
            )

        self.norm_input = LayerNorm(dim, bias = True)
        self.norm_condition = LayerNorm(dim, bias = True)

        self.block = ResnetBlock(dim, dim)

    def forward(self, x, c):
        ff_parser_attn_map = self.ff_parser_attn_map

        # ff-parser: modulate the frequencies of x with the attention map
        dtype = x.dtype
        x = fft2(x)

        if self.dynamic:
            c_complex = fft2(c)
            # view complex spectra as real tensors and fold the (real, imag) axis into channels
            x_as_real, c_as_real = map(torch.view_as_real, (x, c_complex))
            x_as_real, c_as_real = map(lambda t: rearrange(t, 'b d h w ri -> b (d ri) h w'), (x_as_real, c_as_real))

            to_dynamic_input = torch.cat((x_as_real, c_as_real), dim = 1)

            dynamic_ff_parser_attn_map = self.to_dynamic_ff_parser_attn_map(to_dynamic_input)

            ff_parser_attn_map = ff_parser_attn_map + dynamic_ff_parser_attn_map

        x = x * ff_parser_attn_map

        x = ifft2(x).real
        x = x.type(dtype)

        # eq. 3 in the paper: normalise x and c, multiply them, then scale by c
        normed_x = self.norm_input(x)
        normed_c = self.norm_condition(c)
        c = (normed_x * normed_c) * c

        # extra block to allow more information integration. The caller downsamples
        # right after the Conditioning block; conditioning elsewhere may work better.
        return self.block(c)
# U-Net denoiser for segmentation masks, conditioned on the raw input image

@beartype
class Unet(Module):
    """Conditional U-Net used by MedSegDiff to denoise segmentation masks.

    Two parallel encoders run side by side: one over the noisy mask ``x`` and one over
    the conditioning image ``cond`` (a deep copy of the mask encoder with separate
    weights). At every resolution a conditioner module fuses the mask features into
    the conditioning branch. At the bottleneck the two branches are summed. The decoder
    consumes skip connections from the mask branch, and also from the conditioning
    branch when ``skip_connect_condition_fmaps`` is True.

    Args:
        dim: base channel width.
        image_size: side length of the square input; sizes the conditioners.
        mask_channels: channels of the mask being denoised (also the output channels).
        input_img_channels: channels of the conditioning image.
        init_dim: width after the initial 7x7 convs (defaults to ``dim``).
        out_dim: accepted for API compatibility; not used by this class.
        dim_mults: per-resolution width multipliers of ``dim``.
        full_self_attn: per-resolution flag; True -> ``Attention``, False -> ``LinearAttention``.
        attn_dim_head, attn_heads: attention hyper-parameters.
        mid_transformer_depth: depth of the bottleneck transformer.
        self_condition: concatenate a previous x_0 estimate to the mask input.
        resnet_block_groups: GroupNorm groups inside the ResnetBlocks.
        conditioning_klass: conditioner class; ``Conditioning`` gets
            ``dynamic_ff_parser_attn_map`` and ``conditioning_kwargs`` bound.
        skip_connect_condition_fmaps: also concatenate conditioning feature maps in the decoder.
        dynamic_ff_parser_attn_map: let the FF-parser map adapt to the input.
        conditioning_kwargs: extra keyword arguments for ``Conditioning``.
    """

    def __init__(
        self,
        dim,
        image_size,
        mask_channels = 1,
        input_img_channels = 3,
        init_dim = None,
        out_dim = None,
        dim_mults: tuple = (1, 2, 4, 8),
        full_self_attn: tuple = (False, False, False, True),
        attn_dim_head = 32,
        attn_heads = 4,
        mid_transformer_depth = 1,
        self_condition = False,
        resnet_block_groups = 8,
        conditioning_klass = Conditioning,
        skip_connect_condition_fmaps = False,   # whether to concat the conditioning fmaps in the latter decoder upsampling portion of unet
        dynamic_ff_parser_attn_map = False,     # allow for ff-parser to be dynamic based on the input. will exclude condition for now
        conditioning_kwargs: dict = dict(
            dim_head = 32,
            heads = 4,
            depth = 4,
            patch_size = 16
        )
    ):
        super().__init__()

        self.image_size = image_size

        # determine dimensions

        self.input_img_channels = input_img_channels
        self.mask_channels = mask_channels
        self.self_condition = self_condition

        output_channels = mask_channels
        # with self-conditioning the previous estimate is concatenated to the mask input
        mask_channels = mask_channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(mask_channels, init_dim, 7, padding = 3)
        self.cond_init_conv = nn.Conv2d(input_img_channels, init_dim, 7, padding = 3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups = resnet_block_groups)

        # time embeddings

        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # attention related params

        attn_kwargs = dict(
            dim_head = attn_dim_head,
            heads = attn_heads
        )

        # conditioner settings

        if conditioning_klass == Conditioning:
            conditioning_klass = partial(
                Conditioning,
                dynamic = dynamic_ff_parser_attn_map,
                **conditioning_kwargs
            )

        # layers

        num_resolutions = len(in_out)
        assert len(full_self_attn) == num_resolutions

        self.conditioners = ModuleList([])

        self.skip_connect_condition_fmaps = skip_connect_condition_fmaps

        # downsampling encoding blocks

        self.downs = ModuleList([])

        curr_fmap_size = image_size

        for ind, ((dim_in, dim_out), full_attn) in enumerate(zip(in_out, full_self_attn)):
            is_last = ind >= (num_resolutions - 1)
            attn_klass = Attention if full_attn else LinearAttention

            self.conditioners.append(conditioning_klass(curr_fmap_size, dim_in, image_size = curr_fmap_size))

            self.downs.append(ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(attn_klass(dim_in, **attn_kwargs)),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))

            if not is_last:
                curr_fmap_size //= 2

        # middle blocks

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_transformer = Transformer(mid_dim, depth = mid_transformer_depth, **attn_kwargs)
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)

        # condition encoding path mirrors the main encoding path (separate weights via deepcopy)

        self.cond_downs = copy.deepcopy(self.downs)
        self.cond_mid_block1 = copy.deepcopy(self.mid_block1)

        # upsampling decoding blocks

        self.ups = ModuleList([])

        for ind, ((dim_in, dim_out), full_attn) in enumerate(zip(reversed(in_out), reversed(full_self_attn))):
            is_last = ind == (len(in_out) - 1)
            attn_klass = Attention if full_attn else LinearAttention

            skip_connect_dim = dim_in * (2 if self.skip_connect_condition_fmaps else 1)

            self.ups.append(ModuleList([
                block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
                Residual(attn_klass(dim_out, **attn_kwargs)),
                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))

        # projection to predictions
        # The decoder ends with init_dim channels (dims[0]). It is concatenated with the
        # init_dim-channel residual ``r`` in forward, so the input width is init_dim * 2.
        # The previous dim * 2 crashed whenever init_dim != dim.

        self.final_res_block = block_klass(init_dim * 2, dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv2d(dim, output_channels, 1)

    def forward(
        self,
        x,
        time,
        cond,
        x_self_cond = None
    ):
        """Predict the diffusion target for noisy mask ``x`` at ``time`` given image ``cond``."""
        skip_connect_c = self.skip_connect_condition_fmaps

        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim = 1)

        x = self.init_conv(x)
        r = x.clone()

        c = self.cond_init_conv(cond)

        t = self.time_mlp(time)

        h = []

        for (block1, block2, attn, downsample), (cond_block1, cond_block2, cond_attn, cond_downsample), conditioner in zip(self.downs, self.cond_downs, self.conditioners):
            x = block1(x, t)
            c = cond_block1(c, t)

            h.append([x, c] if skip_connect_c else [x])

            x = block2(x, t)
            c = cond_block2(c, t)

            x = attn(x)
            c = cond_attn(c)

            # fuse the mask features into the conditioning branch (x itself is unchanged)
            c = conditioner(x, c)

            h.append([x, c] if skip_connect_c else [x])

            x = downsample(x)
            c = cond_downsample(c)

        x = self.mid_block1(x, t)
        c = self.cond_mid_block1(c, t)

        # merge the conditioning branch into the mask branch at the bottleneck
        x = x + c

        x = self.mid_transformer(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, *h.pop()), dim = 1)
            x = block1(x, t)

            x = torch.cat((x, *h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)
# Gaussian diffusion helpers

def extract(a, t, x_shape):
    """Gather ``a[t]`` for each batch element and reshape to (b, 1, ..., 1) so the
    result broadcasts against a tensor of shape ``x_shape``."""
    batch, *_ = t.shape
    values = a.gather(-1, t)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return values.reshape(batch, *trailing_ones)
def linear_beta_schedule(timesteps):
    """Linearly spaced float64 betas, rescaled by 1000 / timesteps so that 1000 steps
    span 1e-4 -> 0.02."""
    scale = 1000 / timesteps
    start, end = 0.0001 * scale, 0.02 * scale
    return torch.linspace(start, end, timesteps, dtype = torch.float64)
def cosine_beta_schedule(timesteps, s = 0.008):
    """
    Cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Betas are derived from a squared-cosine alpha_cumprod curve and clipped to [0, 0.999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype = torch.float64)
    f = torch.cos(((grid / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = f / f[0]
    ratios = alphas_cumprod[1:] / alphas_cumprod[:-1]
    return torch.clip(1 - ratios, 0, 0.999)
# Medical-segmentation diffusion: Gaussian diffusion trainer/sampler around a Unet
class MedSegDiff(Module):
    """Gaussian diffusion over segmentation masks, conditioned on an input image.

    ``forward(img, cond_img)`` returns the training loss: MSE against the target chosen
    by ``objective``. ``sample(cond_img)`` runs the reverse process. It uses ancestral
    sampling, or DDIM when ``sampling_timesteps`` is less than ``timesteps``, and maps
    the result through ``unnormalize_to_zero_to_one``.
    """
    def __init__(
        self,
        model,
        *,
        timesteps=1000,
        sampling_timesteps=None,
        objective='pred_noise',
        beta_schedule='cosine',
        ddim_sampling_eta=1.
    ):
        """
        Args:
            model: a ``Unet``, or a wrapper that exposes the Unet as ``.module``
                (presumably a distributed/accelerate wrapper - confirm with callers).
            timesteps: number of diffusion steps used for training.
            sampling_timesteps: steps used when sampling (defaults to ``timesteps``);
                a smaller value switches ``sample`` to DDIM.
            objective: 'pred_noise', 'pred_x0' or 'pred_v'.
            beta_schedule: 'linear' or 'cosine'.
            ddim_sampling_eta: DDIM eta; 0 makes the DDIM update noise-free.
        """
        super().__init__()

        # unwrap to reach the underlying Unet when it is wrapped
        self.model = model if isinstance(model, Unet) else model.module

        # mirror the model's configuration
        self.input_img_channels = self.model.input_img_channels
        self.mask_channels = self.model.mask_channels
        self.self_condition = self.model.self_condition
        self.image_size = self.model.image_size

        self.objective = objective

        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alphas_cumprod shifted right by one step, with 1 at t = 0
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling related parameters

        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default number of sampling timesteps to number of timesteps at training
        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper to register schedules (computed in float64) as float32 buffers

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        register_buffer('posterior_variance', posterior_variance)

        # log of the posterior variance, clamped because it is 0 at the first step
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    def predict_start_from_noise(self, x_t, t, noise):
        """x_0 = sqrt(1 / a_bar_t) * x_t - sqrt(1 / a_bar_t - 1) * noise."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Inverse of ``predict_start_from_noise``: recover the noise from x_t and x_0."""
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        """v = sqrt(a_bar_t) * noise - sqrt(1 - a_bar_t) * x_0."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        """x_0 = sqrt(a_bar_t) * x_t - sqrt(1 - a_bar_t) * v."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, c, x_self_cond = None, clip_x_start = False):
        """Run the Unet and convert its output (per ``objective``) into both the
        predicted noise and the predicted x_0, optionally clamping x_0 to [-1, 1]."""
        model_output = self.model(x, t, c, x_self_cond)
        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, c, x_self_cond = None, clip_denoised = True):
        """Posterior mean/variance of p(x_{t-1} | x_t) using the model's x_0 estimate
        (clamped in place to [-1, 1] when ``clip_denoised``)."""
        preds = self.model_predictions(x, t, c, x_self_cond)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t, c, x_self_cond = None, clip_denoised = True):
        """One ancestral reverse step at integer timestep ``t``; returns (x_{t-1}, x_0 estimate)."""
        b, *_, device = *x.shape, x.device
        batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, c = c, x_self_cond = x_self_cond, clip_denoised = clip_denoised)
        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.no_grad()
    def p_sample_loop(self, shape, cond):
        """Full ancestral sampling from pure noise of ``shape``, conditioned on ``cond``."""
        batch, device = shape[0], self.betas.device

        img = torch.randn(shape, device = device)

        x_start = None

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
            # with self-conditioning, feed back the previous x_0 estimate
            self_cond = x_start if self.self_condition else None
            img, x_start = self.p_sample(img, t, cond, self_cond)

        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    # DDIM sampling with ``sampling_timesteps`` steps, conditioned on ``cond_img``
    def ddim_sample(self, shape, cond_img, clip_denoised = True):
        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        img = torch.randn(shape, device = device)

        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            self_cond = x_start if self.self_condition else None
            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, cond_img, self_cond, clip_x_start = clip_denoised)

            # last step: return the x_0 estimate directly
            if time_next < 0:
                img = x_start
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            # eta scales the stochastic part of the update (eta = 0 -> deterministic)
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(img)

            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise

        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def sample(self, cond_img):
        """Sample masks of shape (batch, mask_channels, image_size, image_size) for ``cond_img``."""
        batch_size, device = cond_img.shape[0], self.device
        cond_img = cond_img.to(self.device)

        image_size, mask_channels = self.image_size, self.mask_channels
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
        return sample_fn((batch_size, mask_channels, image_size, image_size), cond_img)

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, cond, noise = None):
        """MSE between the Unet output at noised ``x_start`` and the objective's target."""
        b, c, h, w = x_start.shape
        noise = default(noise, lambda: torch.randn_like(x_start))

        # noise sample

        x = self.q_sample(x_start = x_start, t = t, noise = noise)

        # if doing self-conditioning, 50% of the time, predict x_start from current set of times
        # and condition with unet with that
        # this technique will slow down training by 25%, but seems to lower FID significantly

        x_self_cond = None
        if self.self_condition and random() < 0.5:
            with torch.no_grad():
                # predicting x_0

                x_self_cond = self.model_predictions(x, t, cond).pred_x_start
                x_self_cond.detach_()

        # predict and take gradient step

        model_out = self.model(x, t, cond, x_self_cond)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')

        return F.mse_loss(model_out, target)

    def forward(self, img, cond_img, *args, **kwargs):
        """Training loss for mask ``img`` (values in [0, 1]) given conditioning image ``cond_img``.

        3-D inputs (b, h, w) get a singleton channel axis. The mask is normalised to
        [-1, 1] and a uniformly random timestep is drawn per batch element.
        """
        if img.ndim == 3:
            img = rearrange(img, 'b h w -> b 1 h w')

        if cond_img.ndim == 3:
            cond_img = rearrange(cond_img, 'b h w -> b 1 h w')

        device = self.device
        img, cond_img = img.to(device), cond_img.to(device)

        b, c, h, w, device, img_size, img_channels, mask_channels = *img.shape, img.device, self.image_size, self.input_img_channels, self.mask_channels

        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        assert cond_img.shape[1] == img_channels, f'your input medical must have {img_channels} channels'
        assert img.shape[1] == mask_channels, f'the segmented image must have {mask_channels} channels'

        times = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = normalize_to_neg_one_to_one(img)
        return self.p_losses(img, times, cond_img, *args, **kwargs)
.\lucidrains\med-seg-diff-pytorch\med_seg_diff_pytorch\__init__.py
# 从med_seg_diff_pytorch.med_seg_diff_pytorch模块中导入MedSegDiff和Unet类
from med_seg_diff_pytorch.med_seg_diff_pytorch import MedSegDiff, Unet

MedSegDiff - Pytorch
Implementation of MedSegDiff in Pytorch - SOTA medical segmentation out of Baidu using DDPM and enhanced conditioning on the feature level, with filtering of features in fourier space.
Appreciation
-
StabilityAI for the generous sponsorship, as well as my other sponsors out there
-
Isamu and Daniel for adding a training script for a skin lesion dataset!
Install
$ pip install med-seg-diff-pytorch
Usage
import torch
from med_seg_diff_pytorch import Unet, MedSegDiff
model = Unet(
dim = 64,
image_size = 128,
mask_channels = 1, # segmentation has 1 channel
input_img_channels = 3, # input images have 3 channels
dim_mults = (1, 2, 4, 8)
)
diffusion = MedSegDiff(
model,
timesteps = 1000
).cuda()
segmented_imgs = torch.rand(8, 1, 128, 128) # inputs are normalized from 0 to 1
input_imgs = torch.rand(8, 3, 128, 128)
loss = diffusion(segmented_imgs, input_imgs)
loss.backward()
# after a lot of training
pred = diffusion.sample(input_imgs) # pass in your unsegmented images
pred.shape # predicted segmented images - (8, 1, 128, 128)
Training
Command to run
accelerate launch driver.py --mask_channels=1 --input_img_channels=3 --image_size=64 --data_path='./data' --dim=64 --epochs=100 --batch_size=1 --scale_lr --gradient_accumulation_steps=4
To enable self-conditioning (conditioning on the mask predicted so far), add --self_condition
Todo
Citations
@article{Wu2022MedSegDiffMI,
title = {MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model},
author = {Junde Wu and Huihui Fang and Yu Zhang and Yehui Yang and Yanwu Xu},
journal = {ArXiv},
year = {2022},
volume = {abs/2211.00611}
}
@inproceedings{Hoogeboom2023simpleDE,
title = {simple diffusion: End-to-end diffusion for high resolution images},
author = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
year = {2023}
}
.\lucidrains\med-seg-diff-pytorch\sample.py
# 导入所需的库
import os
import argparse
from tqdm import tqdm
import torch
import torchvision.transforms as transforms
from med_seg_diff_pytorch import Unet, MedSegDiff
from med_seg_diff_pytorch.dataset import ISICDataset, GenericNpyDataset
from accelerate import Accelerator
import skimage.io as io
## Parse command-line arguments ##

def parse_args():
    """Build the sampling script's argument parser and return the parsed namespace."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-od', '--output_dir', type=str, default="output", help="Output dir.")
    parser.add_argument('-ld', '--logging_dir', type=str, default="logs", help="Logging dir.")
    parser.add_argument('-mp', '--mixed_precision', type=str, default="no", choices=["no", "fp16", "bf16"],
                        help="Whether to do mixed precision")
    parser.add_argument('-img', '--img_folder', type=str, default='ISBI2016_ISIC_Part3B_Training_Data',
                        help='The image file path from data_path')
    parser.add_argument('-csv', '--csv_file', type=str, default='ISBI2016_ISIC_Part3B_Training_GroundTruth.csv',
                        help='The csv file to load in from data_path')
    parser.add_argument('-sc', '--self_condition', action='store_true', help='Whether to do self condition')
    # help texts must describe the actual meaning and default of each option
    parser.add_argument('-ic', '--mask_channels', type=int, default=1,
                        help='number of channels of the segmentation mask (default: 1)')
    parser.add_argument('-c', '--input_img_channels', type=int, default=3,
                        help='number of channels of the input image (default: 3)')
    parser.add_argument('-is', '--image_size', type=int, default=128, help='input image size (default: 128)')
    parser.add_argument('-dd', '--data_path', default='./data', help='directory of input image')
    parser.add_argument('-d', '--dim', type=int, default=64, help='dim (default: 64)')
    parser.add_argument('-e', '--epochs', type=int, default=10000, help='number of epochs (default: 10000)')
    parser.add_argument('-bs', '--batch_size', type=int, default=8, help='batch size to train on (default: 8)')
    parser.add_argument('--timesteps', type=int, default=1000, help='number of timesteps (default: 1000)')
    parser.add_argument('-ds', '--dataset', default='generic', help='Dataset to use')
    parser.add_argument('--save_every', type=int, default=100, help='save_every n epochs (default: 100)')
    parser.add_argument('--num_ens', type=int, default=5,
                        help='number of times to sample to make an ensemble of predictions like in the paper (default: 5)')
    parser.add_argument('--load_model_from', default=None, help='path to pt file to load from')
    parser.add_argument('--save_uncertainty', action='store_true',
                        help='Whether to store the uncertainty in predictions (only works for ensembles)')
    return parser.parse_args()
def load_data(args):
    """Build the (unshuffled) inference DataLoader selected by ``args.dataset``.

    Supported datasets:
      * ``'ISIC'``    -- images are resized to ``(image_size, image_size)`` and
                         converted to tensors, then wrapped in ``ISICDataset``
                         built from ``data_path``, ``csv_file`` and ``img_folder``.
      * ``'generic'`` -- arrays go through ToPILImage -> Resize(image_size) ->
                         ToTensor, then are wrapped in ``GenericNpyDataset`` with
                         ``test_flag=True``.

    Args:
        args: parsed command-line namespace (reads ``dataset``, ``image_size``,
            ``batch_size``, ``data_path``, ``csv_file``, ``img_folder``).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding batches of ``args.batch_size``
        in dataset order.

    Raises:
        NotImplementedError: if ``args.dataset`` is not one of the above.
    """
    # Reject unknown datasets up front so the branches below only handle
    # supported names.
    if args.dataset not in ('ISIC', 'generic'):
        raise NotImplementedError(f"Your dataset {args.dataset} hasn't been implemented yet.")

    if args.dataset == 'ISIC':
        preprocess = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size)),
            transforms.ToTensor(),
        ])
        # NOTE(review): flip_p=0.5 is passed even though training=False; whether
        # ISICDataset applies flips in that mode is not visible here.
        dataset = ISICDataset(
            args.data_path,
            args.csv_file,
            args.img_folder,
            transform=preprocess,
            training=False,
            flip_p=0.5,
        )
    else:
        preprocess = transforms.Compose([
            transforms.ToPILImage(),
            # NOTE(review): an int size resizes only the shorter edge, so
            # non-square inputs stay non-square -- confirm inputs are square.
            transforms.Resize(args.image_size),
            transforms.ToTensor(),
        ])
        dataset = GenericNpyDataset(args.data_path, transform=preprocess, test_flag=True)

    # Keep dataset order (no shuffle) so outputs follow the input file order.
    return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
def main():
    """Run ensemble inference with a MedSegDiff model and write predictions.

    For each batch from ``load_data(args)``, the diffusion model is sampled
    ``args.num_ens`` times. The per-pixel mean over those samples is written to
    ``<output_dir>/inference/<stem>.png``. When ``--save_uncertainty`` is set,
    the per-pixel standard deviation is written to ``<stem>_std.png``.
    ``<stem>`` is the input file name without its extension.
    """
    args = parse_args()

    inference_dir = os.path.join(args.output_dir, 'inference')
    os.makedirs(inference_dir, exist_ok=True)

    # Only the device (and mixed-precision setting) of the Accelerator is used
    # here; nothing is passed through accelerator.prepare().
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
    )

    model = Unet(
        dim=args.dim,
        image_size=args.image_size,
        dim_mults=(1, 2, 4, 8),
        mask_channels=args.mask_channels,
        input_img_channels=args.input_img_channels,
        self_condition=args.self_condition,
    )

    data_loader = load_data(args)

    diffusion = MedSegDiff(
        model,
        timesteps=args.timesteps,
    ).to(accelerator.device)

    if args.load_model_from is not None:
        # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
        # host; load_state_dict then copies the tensors into the model's params.
        save_dict = torch.load(args.load_model_from, map_location='cpu')
        diffusion.model.load_state_dict(save_dict['model_state_dict'])

    for imgs, _masks, fnames in tqdm(data_loader):
        # One slot per ensemble member. NOTE(review): this assumes each sample
        # has a single mask channel and the same spatial size as imgs -- confirm
        # for mask_channels > 1 or non-square inputs.
        preds = torch.zeros((imgs.shape[0], args.num_ens, imgs.shape[2], imgs.shape[3]))
        for i in range(args.num_ens):
            preds[:, i:i + 1, :, :] = diffusion.sample(imgs).cpu().detach()

        preds_mean = preds.mean(dim=1)
        # The default (unbiased) std of a single sample is NaN, so report zero
        # uncertainty instead of writing NaN images when no ensemble is drawn.
        if args.num_ens > 1:
            preds_std = preds.std(dim=1)
        else:
            preds_std = torch.zeros_like(preds_mean)

        for idx in range(preds.shape[0]):
            # Derive names from the stem, not str.replace('.npy', ...). Otherwise
            # a non-.npy input keeps its original name and the "_std" image
            # silently overwrites the mean prediction.
            stem = os.path.splitext(fnames[idx])[0]
            io.imsave(os.path.join(inference_dir, stem + '.png'), preds_mean[idx, :, :])
            if args.save_uncertainty:
                io.imsave(os.path.join(inference_dir, stem + '_std.png'), preds_std[idx, :, :])
# Script entry point: run inference only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()


浙公网安备 33010602011771号