AI Toolkit - 一站式扩散模型训练套件

项目标题与描述

AI Toolkit 是由 Ostris 开发的一站式扩散模型训练套件,旨在支持消费级硬件上的最新图像和视频扩散模型训练。该项目既可作为图形界面(GUI)也可作为命令行工具(CLI)运行,设计理念是简单易用但功能全面。

核心价值:

  • 支持多种扩散模型架构(SD、SDXL、Flux等)
  • 提供完整的训练流程(LoRA、Dreambooth、微调等)
  • 优化的消费级硬件支持
  • 丰富的扩展功能(数据集工具、标签生成等)

功能特性

  • 模型训练支持

    • LoRA训练
    • Dreambooth训练
    • 文本反转训练
    • 模型微调
    • 参考图像滑动训练
  • 模型架构支持

    • Stable Diffusion 1.5/2.x
    • SDXL
    • Flux/Flex系列
    • HiDream
    • OmniGen2
    • Qwen Image
  • 实用工具

    • 数据集同步工具(Unsplash/Pexels)
    • 超级标签生成器(LLaVA/Fuyu)
    • 模型转换工具
    • 图像生成工具
  • 高级功能

    • 混合精度训练
    • 梯度累积
    • EMA模型平均
    • 学习率调度
    • 自定义损失函数

安装指南

系统要求

  • NVIDIA GPU (24GB VRAM 推荐)
  • Python 3.10+
  • CUDA 11.8
  • cuDNN 8.6+

安装步骤

  1. 克隆仓库:
git clone https://github.com/ostris/ai-toolkit.git
cd ai-toolkit
  1. 创建虚拟环境:
python -m venv venv
source venv/bin/activate  # Linux/Mac
venv\Scripts\activate  # Windows
  1. 安装依赖:
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
  1. 安装额外依赖(可选):
pip install git+https://github.com/haotian-liu/LLaVA.git

使用说明

基础训练示例

from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import ModelConfig

# 初始化模型
model_config = ModelConfig(
    name_or_path="runwayml/stable-diffusion-v1-5",
    is_v2=False,
    dtype="fp16"
)

sd = StableDiffusion(
    device="cuda",
    model_config=model_config,
    dtype="fp16"
)

# 加载模型
sd.load_model()

# 训练配置
train_config = {
    "learning_rate": 1e-5,
    "max_train_steps": 1000,
    "train_batch_size": 4
}

# 开始训练
sd.train(train_config)

图像生成示例

from toolkit.config_modules import GenerateImageConfig

gen_config = GenerateImageConfig(
    prompts=["A beautiful sunset over mountains"],
    width=512,
    height=512,
    guidance_scale=7.5,
    num_inference_steps=50
)

images = sd.generate_images(gen_config)
images[0].save("sunset.png")

核心代码

模型加载核心代码

def load_model(self):
    """加载扩散模型组件"""
    # 加载文本编码器
    self.text_encoder = CLIPTextModel.from_pretrained(
        self.model_config.name_or_path,
        subfolder="text_encoder",
        torch_dtype=self.torch_dtype
    )
    
    # 加载VAE
    self.vae = AutoencoderKL.from_pretrained(
        self.model_config.name_or_path,
        subfolder="vae",
        torch_dtype=self.torch_dtype
    )
    
    # 加载UNet
    self.unet = UNet2DConditionModel.from_pretrained(
        self.model_config.name_or_path,
        subfolder="unet",
        torch_dtype=self.torch_dtype
    )

训练循环核心代码

def train_loop(self, dataloader, optimizer, lr_scheduler):
    """训练主循环"""
    for epoch in range(self.train_config.num_epochs):
        for batch in dataloader:
            # 前向传播
            latents = self.vae.encode(batch["pixel_values"]).latent_dist.sample()
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, self.noise_scheduler.num_train_timesteps, 
                (latents.shape[0],), device=self.device
            )
            
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
            
            # 获取文本嵌入
            text_embeddings = self.text_encoder(batch["input_ids"])[0]
            
            # 预测噪声
            noise_pred = self.unet(
                noisy_latents, timesteps, text_embeddings
            ).sample
            
            # 计算损失
            loss = F.mse_loss(noise_pred, noise)
            
            # 反向传播
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)
公众号二维码

posted @ 2025-08-19 10:56  qife  阅读(27)  评论(0)    收藏  举报