【cv】cycleGAN代码解析:util包-gat_data.py
前言:关于util包
在深度学习模型项目中,util 文件夹(通常是 "utility" 的缩写)主要用于存放通用工具函数、辅助类或跨模块复用的功能代码,目的是减少代码冗余、提高复用性,并使项目结构更清晰。
其包含的文件通常围绕项目中多个模块(如训练、测试、数据处理、模型构建等)都会用到的共性功能展开,常见作用如下:
1. 数据处理工具
- 提供通用的数据加载、解析、预处理函数(如归一化、标准化、数据增强的基础操作)。
- 实现数据格式转换(如将 CSV 转为张量、将图片路径转为输入张量等)。
- 定义数据集划分(如按比例拆分训练集/验证集)、批量生成(batch generation)的辅助逻辑。
2. 模型辅助工具
- 模型保存与加载(如封装
torch.save()、torch.load()或 TensorFlow 的模型读写,处理路径、设备兼容等细节)。 - 参数初始化(如自定义权重初始化方法,供多个模型共用)。
- 模型结构辅助函数(如计算模型参数量、打印网络结构摘要)。
3. 训练/评估工具
- 评估指标计算(如准确率、精确率、召回率、mIoU、BLEU 等,这些指标可能在训练、验证、测试阶段均会用到)。
- 学习率调度器(如自定义学习率调整策略,可被不同训练脚本复用)。
- 训练过程中的日志记录(如记录损失、指标变化,格式化输出到控制台或文件)。
4. 配置与路径工具
- 配置文件解析(如解析 yaml/json 格式的超参数配置,提取训练轮数、学习率等参数,供整个项目使用)。
- 路径管理(如创建输出目录、检查文件是否存在、拼接路径等,解决跨平台路径格式差异问题)。
5. 可视化工具
- 训练曲线绘制(如损失、准确率随 epoch 变化的曲线)。
- 特征图/注意力图可视化(通用的绘图逻辑,可被多个模型的可视化模块调用)。
- 混淆矩阵、PR 曲线等评估结果的可视化函数。
6. 通用辅助功能
- 日志配置(如初始化 logging 模块,设置日志级别、格式,供项目各模块统一调用)。
- 时间/性能统计(如计算函数运行时间、GPU 内存占用监控)。
- 异常处理(如自定义错误类型、通用的异常捕获与提示逻辑)。
总结
util 文件夹的核心是“通用性”——存放那些不特定属于某个模块(如单独的模型文件、训练脚本),但又被多个模块依赖的功能。通过将这些功能集中管理,既能避免重复编码,也能让项目的核心逻辑(如模型定义、训练流程)更简洁易懂。
from __future__ import print_function # 兼容 Python2/3 的 print 函数语法(历史遗留,现代环境可不需要)
from pathlib import Path # 用于跨平台地处理路径
import tarfile # 处理 .tar.gz 压缩包
import requests # 简单的 HTTP 请求库,用来下载网页与数据文件
from warnings import warn # 发出非致命的警告信息
from zipfile import ZipFile # 处理 .zip 压缩包
from bs4 import BeautifulSoup # 解析 HTML,用于从数据页面抓取可下载的数据集文件名
class GetData(object):
"""A Python script for downloading CycleGAN or pix2pix datasets.
Parameters:
technique (str) -- One of: 'cyclegan' or 'pix2pix'.
verbose (bool) -- If True, print additional information.
Examples:
>>> from util.get_data import GetData
>>> gd = GetData(technique='cyclegan')
>>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
"""
# ↑ 类文档字符串:说明用途(下载 CycleGAN / pix2pix 数据集)、关键参数与使用示例
def __init__(self, technique="cyclegan", verbose=True):
# 下载页面的根 URL 字典(不同 technique 指向不同目录)
url_dict = {
"pix2pix": "http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/",
"cyclegan": "http://efrosgans.eecs.berkeley.edu/pix2pix/datasets",
}
# 注意这里的url_dict字典,里面的键值对形式是:technique: 对应的数据集链接
self.url = url_dict.get(technique.lower()) # 根据 technique 取对应 URL;未匹配则为 None
# technique.lower()代表输入的模型不区分大小写
self._verbose = verbose # 控制是否打印额外信息
def _print(self, text):
# 仅在 verbose=True 时打印信息(内部辅助函数)
if self._verbose:
print(text)
@staticmethod
def _get_options(r):
# 和从网络上下载数据有关的方法
# 核心作用是:从一个 HTTP 响应中解析并筛选出可下载的压缩文件选项(文件名)
# 从 HTTP 响应 r 中解析可下载选项:
# 1) 用 lxml 解析器构造 BeautifulSoup;
# 2) 找出所有带 href 的 <a> 标签;
# 3) 取其文本内容(链接文字),筛选以 ".zip" 或 ".tar.gz" 结尾的条目
soup = BeautifulSoup(r.text, "lxml")
options = [h.text for h in soup.find_all("a", href=True) if h.text.endswith((".zip", "tar.gz"))]
return options
def _present_options(self):
# 把上一个方法中找到的所有可下载文件option列表呈现出来
# 访问数据目录页面,列出可下载的压缩包文件名,提示用户输入编号选择其一
r = requests.get(self.url) # GET 页面 HTML
options = self._get_options(r) # 解析出文件名列表
print("Options:\n")
for i, o in enumerate(options):
print("{0}: {1}".format(i, o)) # 逐条打印:索引: 文件名
choice = input("\nPlease enter the number of the " "dataset above you wish to download:")
return options[int(choice)] # 将用户输入视作整数索引并返回对应文件名
def _download_data(self, dataset_url, save_path):
# 负责真正的下载与解压逻辑
save_path = Path(save_path)
if not save_path.is_dir():
save_path.mkdir(parents=True, exist_ok=True) # 若保存目录不存在则递归创建
base = Path(dataset_url).name # 取出 URL 的最后一段(形如 xxx.zip / xxx.tar.gz)
temp_save_path = save_path / base # 临时下载到保存目录下的同名文件
# 直接一次性下载到内存并写入磁盘(未使用流式下载;大文件时可能占用较多内存)
with open(temp_save_path, "wb") as f:
r = requests.get(dataset_url)
f.write(r.content)
# 根据后缀选择对应的解压对象
if base.endswith(".tar.gz"):
obj = tarfile.open(temp_save_path)
elif base.endswith(".zip"):
obj = ZipFile(temp_save_path, "r")
else:
raise ValueError("Unknown File Type: {0}.".format(base)) # 不支持的文件类型直接报错
self._print("Unpacking Data...") # 若 verbose,提示正在解压
obj.extractall(save_path) # 解压到保存目录
obj.close()
temp_save_path.unlink() # 删除临时压缩包,仅保留解压后的内容
def get(self, save_path, dataset=None):
# 整合调用前面定义的几个方法,完成下载数据集的逻辑
"""
Download a dataset.
Parameters:
save_path (str) -- A directory to save the data to.
dataset (str) -- (optional). A specific dataset to download.
Note: this must include the file extension.
If None, options will be presented for you
to choose from.
Returns:
save_path_full (str) -- the absolute path to the downloaded data.
"""
# 若未指定具体数据集文件名(含后缀),则交互式展示选项;否则直接使用传入的 dataset
if dataset is None:
selected_dataset = self._present_options()
else:
selected_dataset = dataset
# 解压后的目标目录:<save_path>/<压缩包主名(去掉扩展名)>
save_path_full = Path(save_path) / selected_dataset.split(".")[0]
if save_path_full.is_dir():
# 若目录已存在则发出警告并跳过下载(避免覆盖已有数据)
warn(f"\n'{save_path_full}' already exists. Voiding Download.")
else:
# 下载并解压
self._print("Downloading Data...")
url = f"{self.url}/{selected_dataset}" # 拼接最终下载 URL(注意:即使 self.url 末尾带/,// 在 HTTP 中也可被接受)
self._download_data(url, save_path=save_path)
return save_path_full.resolve() # 返回下载(或已存在)数据目录的绝对路径

浙公网安备 33010602011771号