3D UNet Pseudocode Implementation

The blocks below are written as runnable PyTorch and assume the following imports:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
Overall Network Structure
class VideoConditionalUNet(nn.Module):
    def __init__(self, in_channels=4, out_channels=3, condition_dim=256):
        super().__init__()
        # Timestep embedding
        self.time_embed = TimeEmbedding(projection_dim=512)
        # Input convolution
        self.input_conv = nn.Conv3d(in_channels, 64, kernel_size=3, padding=1)
        # Downsampling path
        self.down1 = DownBlock(64, 128, condition_dim, dropout=0.1)
        self.down2 = DownBlock(128, 256, condition_dim, dropout=0.2)
        self.down3 = DownBlock(256, 512, condition_dim, dropout=0.3)
        # Bottleneck
        self.mid_block1 = ResBlockWithAttention(512, condition_dim, num_heads=8)
        self.mid_block2 = ResBlock(512, 512, condition_dim)
        # Upsampling path
        self.up1 = UpBlock(512, 256, condition_dim, dropout=0.3)
        self.up2 = UpBlock(256, 128, condition_dim, dropout=0.2)
        self.up3 = UpBlock(128, 64, condition_dim, dropout=0.1)
        # Output layer
        self.output_conv = nn.Sequential(
            GroupNorm(64),
            nn.SiLU(),
            nn.Conv3d(64, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x, t, condition):
        """
        x: noisy video (B, C, T, H, W)
        t: diffusion timesteps (B,)
        condition: condition vectors (B, D)
        """
        # Timestep embedding
        t_emb = self.time_embed(t)
        # Initial projection
        x = self.input_conv(x)
        # Downsampling path; each block also returns its pre-downsample
        # features, which serve as the skip connections
        skip1, h = self.down1(x, t_emb, condition)
        skip2, h = self.down2(h, t_emb, condition)
        skip3, h = self.down3(h, t_emb, condition)
        # Bottleneck
        h = self.mid_block1(h, t_emb, condition)
        h = self.mid_block2(h, t_emb, condition)
        # Upsampling path, concatenating skips at matching resolutions
        h = self.up1(h, t_emb, condition, skip3)
        h = self.up2(h, t_emb, condition, skip2)
        h = self.up3(h, t_emb, condition, skip1)
        # Output
        return self.output_conv(h)
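
As a quick sanity check, a forward pass with hypothetical dummy tensors (T, H, W chosen so that three stride-2 downsampling steps divide evenly); the model relies on the building blocks defined in the sections below:

model = VideoConditionalUNet(in_channels=4, out_channels=3, condition_dim=256)
x = torch.randn(1, 4, 8, 16, 16)   # noisy video (B, C, T, H, W)
t = torch.randint(0, 1000, (1,))   # diffusion timesteps (B,)
cond = torch.randn(1, 256)         # condition vectors (B, D)
print(model(x, t, cond).shape)     # torch.Size([1, 3, 8, 16, 16])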
Time Embedding Layer
class TimeEmbedding(nn.Module):
    def __init__(self, projection_dim=512):
        super().__init__()
        self.projection_dim = projection_dim
        self.fc1 = nn.Linear(projection_dim, projection_dim)
        self.activation = nn.SiLU()
        self.fc2 = nn.Linear(projection_dim, projection_dim)

    def forward(self, t):
        # Sinusoidal position encoding of the timestep
        half_dim = self.projection_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        # Nonlinear projection
        emb = self.fc1(emb)
        emb = self.activation(emb)
        return self.fc2(emb)
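
A quick check of the output shape (hypothetical timestep values):

time_embed = TimeEmbedding(projection_dim=512)
t = torch.randint(0, 1000, (4,))  # a batch of 4 diffusion timesteps
print(time_embed(t).shape)        # torch.Size([4, 512])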
Downsampling and Upsampling Blocks
class DownBlock(nn.Module):
    def __init__(self, in_c, out_c, cond_dim, dropout=0.1):
        super().__init__()
        self.res1 = ResBlock(in_c, out_c, cond_dim)
        self.attn = SpatialTemporalAttention(out_c)
        self.res2 = ResBlock(out_c, out_c, cond_dim)
        self.downsample = nn.Conv3d(out_c, out_c, kernel_size=3, stride=2, padding=1)
        self.dropout = nn.Dropout3d(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x, t_emb, condition):
        # First residual block
        h = self.res1(x, t_emb, condition)
        # Spatio-temporal attention
        h = self.attn(h)
        # Second residual block
        h = self.res2(h, t_emb, condition)
        # Keep the pre-downsample features as the skip connection so that
        # their resolution matches the corresponding UpBlock output
        skip = h
        # Downsample (halves T, H, W)
        h = self.downsample(h)
        return skip, self.dropout(h)
class UpBlock(nn.Module):
    def __init__(self, in_c, out_c, cond_dim, dropout=0.1):
        super().__init__()
        self.upsample = nn.ConvTranspose3d(in_c, in_c, kernel_size=3, stride=2,
                                           padding=1, output_padding=1)
        self.res1 = ResBlock(in_c * 2, in_c, cond_dim)  # handles the concatenated skip
        self.attn = SpatialTemporalAttention(in_c)
        self.res2 = ResBlock(in_c, out_c, cond_dim)
        self.dropout = nn.Dropout3d(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x, t_emb, condition, skip=None):
        # Upsample (doubles T, H, W)
        x = self.upsample(x)
        # Concatenate the skip connection along the channel axis
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        # First residual block
        x = self.res1(x, t_emb, condition)
        # Spatio-temporal attention
        x = self.attn(x)
        # Second residual block
        x = self.res2(x, t_emb, condition)
        return self.dropout(x)
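
To see how the skip connection lines up, a hypothetical round trip through one DownBlock/UpBlock pair (dummy tensors; t_emb is 512-dim to match TimeEmbedding above):

down = DownBlock(64, 128, cond_dim=256)
up = UpBlock(128, 64, cond_dim=256)
x = torch.randn(1, 64, 4, 16, 16)
t_emb, cond = torch.randn(1, 512), torch.randn(1, 256)
skip, h = down(x, t_emb, cond)   # skip: (1, 128, 4, 16, 16), h: (1, 128, 2, 8, 8)
out = up(h, t_emb, cond, skip)   # upsampled h matches skip's resolution for the concat
print(out.shape)                 # torch.Size([1, 64, 4, 16, 16])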
Residual Blocks
class ResBlock(nn.Module):
    # time_dim defaults to 512 to match TimeEmbedding(projection_dim=512) above
    def __init__(self, in_c, out_c, cond_dim, time_dim=512):
        super().__init__()
        self.norm1 = GroupNorm(in_c)
        self.conv1 = nn.Conv3d(in_c, out_c, 3, padding=1)
        # Projections of the time and condition embeddings
        self.time_dense = nn.Linear(time_dim, out_c)
        self.condition_dense = nn.Linear(cond_dim, out_c)
        # Second layer
        self.norm2 = GroupNorm(out_c)
        self.conv2 = nn.Conv3d(out_c, out_c, 3, padding=1)
        # Shortcut for mismatched input/output channel counts
        self.shortcut = nn.Identity()
        if in_c != out_c:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_c, out_c, kernel_size=1),
                GroupNorm(out_c),
            )

    def forward(self, x, t_emb, condition):
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)
        # Inject the conditioning signals
        cond_emb = self.condition_dense(condition)  # (B, D) -> (B, out_c)
        time_emb = self.time_dense(t_emb)           # (B, time_dim) -> (B, out_c)
        # Broadcast over the temporal and spatial axes
        B, C, T, H, W = h.shape
        cond_emb = cond_emb.view(B, C, 1, 1, 1).expand_as(h)
        time_emb = time_emb.view(B, C, 1, 1, 1).expand_as(h)
        h = h + cond_emb + time_emb
        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)
        return h + self.shortcut(x)
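
A shape check for ResBlock with the dimensions used in the network above (dummy tensors):

block = ResBlock(in_c=64, out_c=128, cond_dim=256, time_dim=512)
x = torch.randn(2, 64, 4, 8, 8)     # (B, C, T, H, W)
t_emb = torch.randn(2, 512)         # output of TimeEmbedding
cond = torch.randn(2, 256)          # condition vector
print(block(x, t_emb, cond).shape)  # torch.Size([2, 128, 4, 8, 8])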
class ResBlockWithAttention(nn.Module):
    def __init__(self, ch, cond_dim, num_heads=8):
        super().__init__()
        self.res_block = ResBlock(ch, ch, cond_dim)
        self.attention = CrossAttention(ch, cond_dim, num_heads=num_heads)

    def forward(self, x, t_emb, condition):
        x = self.res_block(x, t_emb, condition)
        x = self.attention(x, condition)
        return x
Attention Modules
class SpatialTemporalAttention(nn.Module):
    """Joint 3D spatio-temporal self-attention over all (T, H, W) positions."""
    def __init__(self, ch, num_heads=8):
        super().__init__()
        self.norm = GroupNorm(ch)
        self.attn = nn.MultiheadAttention(ch, num_heads, batch_first=True)

    def forward(self, x):
        B, C, T, H, W = x.shape
        # Flatten to a spatio-temporal token sequence (B, T*H*W, C)
        h = self.norm(x).permute(0, 2, 3, 4, 1).reshape(B, T * H * W, C)
        # Multi-head self-attention over the token sequence (pre-norm)
        h, _ = self.attn(h, h, h)
        # Restore (B, C, T, H, W) and add a residual connection
        return x + h.reshape(B, T, H, W, C).permute(0, 4, 1, 2, 3)
class CrossAttention(nn.Module):
    def __init__(self, ch, cond_dim, num_heads=8):
        super().__init__()
        self.query_proj = nn.Linear(ch, ch)
        self.key_proj = nn.Linear(cond_dim, ch)
        self.value_proj = nn.Linear(cond_dim, ch)
        self.out_proj = nn.Linear(ch, ch)
        self.num_heads = num_heads

    def forward(self, x, condition):
        """Cross-attention from video features to the condition vector."""
        B, C, T, H, W = x.shape
        # Input features (B, C, T, H, W) -> (B, T*H*W, C)
        x_seq = x.permute(0, 2, 3, 4, 1).reshape(B, T * H * W, C)
        # Queries come from the video features
        Q = self.query_proj(x_seq)                   # (B, seq, ch)
        # Keys/values come from the condition, treated as a length-1 token
        # sequence; the same code applies to longer condition sequences
        # (e.g. text embeddings)
        K = self.key_proj(condition).unsqueeze(1)    # (B, 1, ch)
        V = self.value_proj(condition).unsqueeze(1)  # (B, 1, ch)
        # Split into heads: (B, num_heads, seq, head_dim)
        d = C // self.num_heads
        Q = Q.view(B, -1, self.num_heads, d).transpose(1, 2)
        K = K.view(B, -1, self.num_heads, d).transpose(1, 2)
        V = V.view(B, -1, self.num_heads, d).transpose(1, 2)
        # Scaled dot-product attention; softmax runs over the key axis
        attn_weights = (Q @ K.transpose(-2, -1) * d ** -0.5).softmax(dim=-1)
        attn_output = (attn_weights @ V).transpose(1, 2).reshape(B, -1, C)
        # Output projection
        out = self.out_proj(attn_output)
        # Restore (B, C, T, H, W) and add a residual connection
        return out.reshape(B, T, H, W, C).permute(0, 4, 1, 2, 3) + x
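
Thanks to the residual connection, CrossAttention preserves the input shape; a check with dummy tensors matching the bottleneck dimensions:

attn = CrossAttention(ch=512, cond_dim=256, num_heads=8)
x = torch.randn(1, 512, 2, 4, 4)
cond = torch.randn(1, 256)
print(attn(x, cond).shape)  # torch.Size([1, 512, 2, 4, 4])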
Group Normalization Layer
class GroupNorm(nn.Module):
    """A more stable 3D GroupNorm variant: normalizes each frame independently."""
    def __init__(self, num_channels, groups=32):
        super().__init__()
        self.gn = nn.GroupNorm(groups, num_channels)

    def forward(self, x):
        # Input: (B, C, T, H, W)
        x = x.permute(0, 2, 1, 3, 4)  # (B, T, C, H, W)
        b, t, c, h, w = x.shape
        # Fold the temporal axis into the batch so GN runs per frame
        x = x.reshape(b * t, c, h, w)
        x = self.gn(x)
        x = x.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4)
        return x
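Note that nn.GroupNorm itself accepts inputs of shape (B, C, *) and would normalize each group over (T, H, W) jointly; the wrapper above instead computes per-frame statistics. A quick check that both preserve the input shape:

x = torch.randn(2, 64, 4, 8, 8)
per_frame = GroupNorm(64)(x)     # per-frame statistics
joint = nn.GroupNorm(32, 64)(x)  # joint (T, H, W) statistics
print(per_frame.shape == joint.shape == x.shape)  # True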
