Loading

【cv】cycleGAN代码解析:util包-gat_data.py

前言:关于util包

在深度学习模型项目中,util 文件夹(通常是 "utility" 的缩写)主要用于存放通用工具函数、辅助类或跨模块复用的功能代码,目的是减少代码冗余、提高复用性,并使项目结构更清晰。

其包含的文件通常围绕项目中多个模块(如训练、测试、数据处理、模型构建等)都会用到的共性功能展开,常见作用如下:

1. 数据处理工具

  • 提供通用的数据加载、解析、预处理函数(如归一化、标准化、数据增强的基础操作)。
  • 实现数据格式转换(如将 CSV 转为张量、将图片路径转为输入张量等)。
  • 定义数据集划分(如按比例拆分训练集/验证集)、批量生成(batch generation)的辅助逻辑。

2. 模型辅助工具

  • 模型保存与加载(如封装 torch.save()torch.load() 或 TensorFlow 的模型读写,处理路径、设备兼容等细节)。
  • 参数初始化(如自定义权重初始化方法,供多个模型共用)。
  • 模型结构辅助函数(如计算模型参数量、打印网络结构摘要)。

3. 训练/评估工具

  • 评估指标计算(如准确率、精确率、召回率、mIoU、BLEU 等,这些指标可能在训练、验证、测试阶段均会用到)。
  • 学习率调度器(如自定义学习率调整策略,可被不同训练脚本复用)。
  • 训练过程中的日志记录(如记录损失、指标变化,格式化输出到控制台或文件)。

4. 配置与路径工具

  • 配置文件解析(如解析 yaml/json 格式的超参数配置,提取训练轮数、学习率等参数,供整个项目使用)。
  • 路径管理(如创建输出目录、检查文件是否存在、拼接路径等,解决跨平台路径格式差异问题)。

5. 可视化工具

  • 训练曲线绘制(如损失、准确率随 epoch 变化的曲线)。
  • 特征图/注意力图可视化(通用的绘图逻辑,可被多个模型的可视化模块调用)。
  • 混淆矩阵、PR 曲线等评估结果的可视化函数。

6. 通用辅助功能

  • 日志配置(如初始化 logging 模块,设置日志级别、格式,供项目各模块统一调用)。
  • 时间/性能统计(如计算函数运行时间、GPU 内存占用监控)。
  • 异常处理(如自定义错误类型、通用的异常捕获与提示逻辑)。

总结

util 文件夹的核心是“通用性”——存放那些不特定属于某个模块(如单独的模型文件、训练脚本),但又被多个模块依赖的功能。通过将这些功能集中管理,既能避免重复编码,也能让项目的核心逻辑(如模型定义、训练流程)更简洁易懂。

from __future__ import print_function  # 兼容 Python2/3 的 print 函数语法(历史遗留,现代环境可不需要)
from pathlib import Path               # 用于跨平台地处理路径
import tarfile                         # 处理 .tar.gz 压缩包
import requests                        # 简单的 HTTP 请求库,用来下载网页与数据文件
from warnings import warn              # 发出非致命的警告信息
from zipfile import ZipFile            # 处理 .zip 压缩包
from bs4 import BeautifulSoup          # 解析 HTML,用于从数据页面抓取可下载的数据集文件名

class GetData(object):
    """A Python script for downloading CycleGAN or pix2pix datasets.
    Parameters:
        technique (str) -- One of: 'cyclegan' or 'pix2pix'.
        verbose (bool)  -- If True, print additional information.
    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.
    """
    # ↑ 类文档字符串:说明用途(下载 CycleGAN / pix2pix 数据集)、关键参数与使用示例

    def __init__(self, technique="cyclegan", verbose=True):
	
        # 下载页面的根 URL 字典(不同 technique 指向不同目录)
        url_dict = {
            "pix2pix": "http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/",
            "cyclegan": "http://efrosgans.eecs.berkeley.edu/pix2pix/datasets",
        }
	# 注意这里的url_dict字典,里面的键值对形式是:technique: 对应的数据集链接
	
        self.url = url_dict.get(technique.lower())  # 根据 technique 取对应 URL;未匹配则为 None
	# technique.lower()代表输入的模型不区分大小写
	
        self._verbose = verbose                     # 控制是否打印额外信息

    def _print(self, text):
        # 仅在 verbose=True 时打印信息(内部辅助函数)
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
	# 和从网络上下载数据有关的方法
	# 核心作用是:从一个 HTTP 响应中解析并筛选出可下载的压缩文件选项(文件名)
	
        # 从 HTTP 响应 r 中解析可下载选项:
        # 1) 用 lxml 解析器构造 BeautifulSoup;
        # 2) 找出所有带 href 的 <a> 标签;
        # 3) 取其文本内容(链接文字),筛选以 ".zip" 或 ".tar.gz" 结尾的条目
        soup = BeautifulSoup(r.text, "lxml")
        options = [h.text for h in soup.find_all("a", href=True) if h.text.endswith((".zip", "tar.gz"))]
        return options

    def _present_options(self):
	# 把上一个方法中找到的所有可下载文件option列表呈现出来
	
        # 访问数据目录页面,列出可下载的压缩包文件名,提示用户输入编号选择其一
        r = requests.get(self.url)                 # GET 页面 HTML
        options = self._get_options(r)             # 解析出文件名列表
        print("Options:\n")
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))         # 逐条打印:索引: 文件名
        choice = input("\nPlease enter the number of the " "dataset above you wish to download:")
        return options[int(choice)]                # 将用户输入视作整数索引并返回对应文件名

    def _download_data(self, dataset_url, save_path):
        # 负责真正的下载与解压逻辑
        save_path = Path(save_path)
        if not save_path.is_dir():
            save_path.mkdir(parents=True, exist_ok=True)  # 若保存目录不存在则递归创建

        base = Path(dataset_url).name           # 取出 URL 的最后一段(形如 xxx.zip / xxx.tar.gz)
        temp_save_path = save_path / base       # 临时下载到保存目录下的同名文件

        # 直接一次性下载到内存并写入磁盘(未使用流式下载;大文件时可能占用较多内存)
        with open(temp_save_path, "wb") as f:
            r = requests.get(dataset_url)
            f.write(r.content)

        # 根据后缀选择对应的解压对象
        if base.endswith(".tar.gz"):
            obj = tarfile.open(temp_save_path)
        elif base.endswith(".zip"):
            obj = ZipFile(temp_save_path, "r")
        else:
            raise ValueError("Unknown File Type: {0}.".format(base))  # 不支持的文件类型直接报错

        self._print("Unpacking Data...")  # 若 verbose,提示正在解压
        obj.extractall(save_path)         # 解压到保存目录
        obj.close()
        temp_save_path.unlink()           # 删除临时压缩包,仅保留解压后的内容

    def get(self, save_path, dataset=None):
	# 整合调用前面定义的几个方法,完成下载数据集的逻辑
        """
        Download a dataset.

        Parameters:
            save_path (str) -- A directory to save the data to.
            dataset (str)   -- (optional). A specific dataset to download.
                            Note: this must include the file extension.
                            If None, options will be presented for you
                            to choose from.

        Returns:
            save_path_full (str) -- the absolute path to the downloaded data.

        """
        # 若未指定具体数据集文件名(含后缀),则交互式展示选项;否则直接使用传入的 dataset
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset

        # 解压后的目标目录:<save_path>/<压缩包主名(去掉扩展名)>
        save_path_full = Path(save_path) / selected_dataset.split(".")[0]

        if save_path_full.is_dir():
            # 若目录已存在则发出警告并跳过下载(避免覆盖已有数据)
            warn(f"\n'{save_path_full}' already exists. Voiding Download.")
        else:
            # 下载并解压
            self._print("Downloading Data...")
            url = f"{self.url}/{selected_dataset}"   # 拼接最终下载 URL(注意:即使 self.url 末尾带/,// 在 HTTP 中也可被接受)
            self._download_data(url, save_path=save_path)

        return save_path_full.resolve()  # 返回下载(或已存在)数据目录的绝对路径

posted @ 2025-09-25 15:15  SaTsuki26681534  阅读(14)  评论(0)    收藏  举报