Encoder-Decoder Learning for Speech — Reading the U-Net Code

import argparse
import os

import json5
import numpy as np
import torch
from torch.utils.data import DataLoader
from util.utils import initialize_config


def main(config, resume):
    # main() takes two caller-supplied arguments: the parsed config dict and the resume flag.
    torch.manual_seed(config["seed"])  # seed PyTorch's RNG (both CPU and GPU) with the "seed" value from the JSON config
    np.random.seed(config["seed"])     # seed NumPy's RNG with the same value, for reproducibility

    # Without a seed, results differ on every run:
    #   torch.rand(3)  # first run:  tensor([0.4387, 0.0385, 0.9119])
    #   torch.rand(3)  # second run: tensor([0.1345, 0.7892, 0.6543])
    # With a seed, results repeat exactly:
    #   torch.manual_seed(42)
    #   torch.rand(3)  # first run:  tensor([0.8823, 0.9150, 0.3829])
    #   torch.manual_seed(42)
    #   torch.rand(3)  # second run: tensor([0.8823, 0.9150, 0.3829])  # identical!

    train_dataloader = DataLoader(
        dataset=initialize_config(config["train_dataset"]),
        batch_size=config["train_dataloader"]["batch_size"],
        num_workers=config["train_dataloader"]["num_workers"],
        shuffle=config["train_dataloader"]["shuffle"],
        pin_memory=config["train_dataloader"]["pin_memory"]
    )

    valid_dataloader = DataLoader(
        dataset=initialize_config(config["validation_dataset"]),
        num_workers=1,
        batch_size=1
    )

    model = initialize_config(config["model"])

    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=config["optimizer"]["lr"],
        betas=(config["optimizer"]["beta1"], config["optimizer"]["beta2"])
    )

    loss_function = initialize_config(config["loss_function"])

    trainer_class = initialize_config(config["trainer"], pass_args=False)
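    # NOTE: initialize_config lives in util/utils.py and is not shown in the post. A minimal
    # sketch of what such a config-driven factory plausibly looks like (an assumption, not the
    # repo's verbatim code), where each config entry carries "module", "main" and "args" keys:
    #
    #   import importlib
    #
    #   def initialize_config(module_cfg, pass_args=True):
    #       module = importlib.import_module(module_cfg["module"])  # e.g. "model.unet_basic"
    #       if pass_args:
    #           # instantiate the named class/function with the configured keyword arguments
    #           return getattr(module, module_cfg["main"])(**module_cfg["args"])
    #       # pass_args=False: hand back the class itself so it can be constructed manually,
    #       # which is exactly how the trainer class here is used
    #       return getattr(module, module_cfg["main"])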

    trainer = trainer_class(
        config=config,
        resume=resume,
        model=model,
        loss_function=loss_function,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        validation_dataloader=valid_dataloader
    )

    trainer.train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Wave-U-Net for Speech Enhancement")
    parser.add_argument("-C", "--configuration", required=True, type=str, help="Configuration (*.json).")
    parser.add_argument("-R", "--resume", action="store_true", help="Resume experiment from latest checkpoint.")
    args = parser.parse_args()

    configuration = json5.load(open(args.configuration))
    configuration["experiment_name"], _ = os.path.splitext(os.path.basename(args.configuration))
    configuration["config_path"] = args.configuration

    main(configuration, resume=args.resume)
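
For orientation: reconstructed purely from the keys main() reads, the *.json5 configuration must parse to a dict shaped roughly like the sketch below. The concrete values and the {...} sub-dicts are placeholders of mine, not the repo's actual config; each {...} entry is what initialize_config consumes.

configuration = {
    "seed": 0,
    "train_dataset": {...},
    "validation_dataset": {...},
    "train_dataloader": {"batch_size": 32, "num_workers": 4, "shuffle": True, "pin_memory": True},
    "model": {...},
    "optimizer": {"lr": 1e-3, "beta1": 0.9, "beta2": 0.999},
    "loss_function": {...},
    "trainer": {...},
}

Assuming the script above is saved as train.py (the post does not name the file), a run would be launched as "python train.py -C config/train.json5", adding -R to resume from the latest checkpoint.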

Below is the U-Net itself:

import torch
import torch.nn as nn
import torch.nn.functional as F


class DownSamplingLayer(nn.Module):  # a down-sampling block, defined as a class inheriting from nn.Module
    def __init__(self, channel_in, channel_out, dilation=1, kernel_size=15, stride=1, padding=7):
        # Only channel_in and channel_out must be supplied when instantiating DownSamplingLayer();
        # dilation=1, kernel_size=15, stride=1 and padding=7 are filled in as defaults.
        super(DownSamplingLayer, self).__init__()  # call the parent nn.Module initializer -- the standard pattern for Python inheritance
        self.main = nn.Sequential(
            nn.Conv1d(channel_in, channel_out, kernel_size=kernel_size,
                      stride=stride, padding=padding, dilation=dilation),
            nn.BatchNorm1d(channel_out),
            nn.LeakyReLU(negative_slope=0.1)
        )

    def forward(self, ipt):  # ipt: the input tensor
        return self.main(ipt)
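
# Quick sanity check (my snippet, not from the post): with kernel_size=15, stride=1 and
# padding=7 the convolution preserves the time length, so this block only changes the
# channel count. The actual halving of the time axis happens later, in Model.forward:
#   layer = DownSamplingLayer(channel_in=1, channel_out=24)
#   layer(torch.randn(4, 1, 16384)).shape  # -> torch.Size([4, 24, 16384])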

class UpSamplingLayer(nn.Module):
    def __init__(self, channel_in, channel_out, kernel_size=5, stride=1, padding=2):  # padding = (kernel_size - 1) // 2 = 2, i.e. "same" padding at stride 1
        super(UpSamplingLayer, self).__init__()
        self.main = nn.Sequential(
            nn.Conv1d(channel_in, channel_out, kernel_size=kernel_size,
                      stride=stride, padding=padding),
            nn.BatchNorm1d(channel_out),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        )

    def forward(self, ipt):
        return self.main(ipt)

class Model(nn.Module):
    def __init__(self, n_layers=12, channels_interval=24):  # the fixed hyperparameters everything below is derived from
        super(Model, self).__init__()
        self.n_layers = n_layers
        self.channels_interval = channels_interval
        # Encoder input widths: 1 channel for the raw waveform, then i * channels_interval for i = 1..n_layers-1.
        encoder_in_channels_list = [1] + [i * self.channels_interval for i in range(1, self.n_layers)]
        # Encoder output widths: layer i (i = 1..n_layers) outputs i * channels_interval channels.
        encoder_out_channels_list = [i * self.channels_interval for i in range(1, self.n_layers + 1)]
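        # With the defaults (n_layers=12, channels_interval=24) these evaluate to:
        #   encoder_in_channels_list  = [1, 24, 48, 72, ..., 264]   (12 entries)
        #   encoder_out_channels_list = [24, 48, 72, ..., 264, 288] (12 entries)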

        self.encoder = nn.ModuleList()
        for i in range(self.n_layers):  # append one down-sampling block per layer, reading its in/out widths from the lists above
            self.encoder.append(
                DownSamplingLayer(
                    channel_in=encoder_in_channels_list[i],
                    channel_out=encoder_out_channels_list[i]
                )
            )

        self.middle = nn.Sequential(
            nn.Conv1d(self.n_layers * self.channels_interval, self.n_layers * self.channels_interval, 15, stride=1,
                      padding=7),
            nn.BatchNorm1d(self.n_layers * self.channels_interval),
            nn.LeakyReLU(negative_slope=0.1, inplace=True)
        )
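        # The middle (bottleneck) block changes neither the channel count (n_layers * channels_interval
        # = 288 by default) nor the time length: kernel 15 with padding 7 at stride 1 is "same" padding.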

        decoder_in_channels_list = [(2 * i + 1) * self.channels_interval for i in range(1, self.n_layers)] + [
            2 * self.n_layers * self.channels_interval]
        decoder_in_channels_list = decoder_in_channels_list[::-1]
        decoder_out_channels_list = encoder_out_channels_list[::-1]
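        # Decoder widths follow from the skip connections: the decoder block whose matching encoder
        # skip has i * 24 channels receives that skip concatenated with the previous decoder output of
        # (i + 1) * 24 channels, hence (2i + 1) * 24 input channels; the first (deepest) block instead
        # sees middle output + deepest skip = 2 * 12 * 24 = 576. After the [::-1] reversals:
        #   decoder_in_channels_list  = [576, 552, 504, ..., 120, 72]
        #   decoder_out_channels_list = [288, 264, 240, ..., 48, 24]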
        self.decoder = nn.ModuleList()
        for i in range(self.n_layers):
            self.decoder.append(
                UpSamplingLayer(
                    channel_in=decoder_in_channels_list[i],
                    channel_out=decoder_out_channels_list[i]
                )
            )

        self.out = nn.Sequential(
            # 1 + channels_interval = 25 input channels: the last decoder output (24) plus the raw waveform (1).
            nn.Conv1d(1 + self.channels_interval, 1, kernel_size=1, stride=1),
            nn.Tanh()  # squash the enhanced waveform into [-1, 1]
        )
## Above: the down-/up-sampling blocks are defined first, then the whole encoder-decoder network is assembled from them.
## Below: the complete forward pass.
    def forward(self, input):
        tmp = []   # holds every encoder output, to be reused later as skip connections
        o = input  # o is the running activation, starting from the raw input

        # Encoder (down-sampling) path
        for i in range(self.n_layers):
            o = self.encoder[i](o)  # encoder block i transforms o ...
            tmp.append(o)           # ... and its output is saved for the skip connection
            # [batch_size, channels, T // 2]
            o = o[:, :, ::2]  # down-sample the time axis by keeping every other sample

        o = self.middle(o)  # pass the fully encoded signal through the bottleneck

        # Decoder (up-sampling) path
        for i in range(self.n_layers):
            # [batch_size, channels, T * 2]
            o = F.interpolate(o, scale_factor=2, mode="linear", align_corners=True)
            # F.interpolate with mode="linear" up-samples along the time axis: the new in-between
            # positions are filled by linearly interpolating their neighbours, doubling the resolution.

            # Skip connection: torch.cat (short for "concatenate") glues tensors together.
            # The current decoder input is concatenated with the matching encoder output
            # along the channel dimension (dim=1) to form the next decoder block's input.
            o = torch.cat([o, tmp[self.n_layers - i - 1]], dim=1)
            o = self.decoder[i](o)  # run the fused tensor through decoder block i; after all 12 blocks this is the final feature map

        o = torch.cat([o, input], dim=1)  # concatenate the decoder output with the original input waveform
        o = self.out(o)                   # the 1x1 convolution + Tanh produce the network's output
        return o
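
To see the whole thing run end to end (my snippet, not from the post): the encoder halves the time axis 12 times and the decoder doubles it 12 times, so the input length should be divisible by 2**12 = 4096 for the skip concatenations to line up.

model = Model()                  # n_layers=12, channels_interval=24
wav = torch.randn(2, 1, 16384)   # [batch, 1, T]; 16384 = 4 * 4096
out = model(wav)
print(out.shape)                 # torch.Size([2, 1, 16384]) -- same shape as the input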

Below are some notes on the skip connection, i.e. torch.cat:

import torch

# Two building blocks
block_a = torch.tensor([[1, 2], [3, 4]])  # 2x2
block_b = torch.tensor([[5, 6], [7, 8]])  # 2x2

# Concatenate along dim=0: stack the rows, so block_b goes underneath
stacked = torch.cat([block_a, block_b], dim=0)
"""
[[1, 2],
 [3, 4],
 [5, 6],  ← block_b attached below
 [7, 8]]
shape: 4x2
"""

# Concatenate along dim=1: join the columns, so block_b goes to the right
side_by_side = torch.cat([block_a, block_b], dim=1)
"""
[[1, 2, 5, 6],  ← block_b attached on the right
 [3, 4, 7, 8]]
shape: 2x4
"""
