神经网络图像压缩代码

参加第五届人工智能竞赛,选的图像编码赛道(钱多),纯记录下,这神经网络结构打榜分数也不高,我觉得重要的在于找到一种合适于图像压缩任务的结构,训练倒是其次,主办方让完全采用AI的方式去做,我觉得在网络结构的选取上,势必要加入一些自己对图像的专业理解的,只是这种理解不能以传统的方式表现出来。

点击查看代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import numpy as np
import lpips
from pytorch_msssim import ms_ssim
from torchvision.transforms.functional import normalize

class ImageCompressor(nn.Module):
    def __init__(self, compression_ratio=8):
        super(ImageCompressor, self).__init__()
        
        # 编码器 - 减少通道数以提高压缩率
        # 编码器部分
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 4, kernel_size=3, padding=1),
            nn.ReLU()
        )
        
        # 解码器部分相应修改
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(64, 3, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

def compress_image(image_path, output_path, model_path=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ImageCompressor().to(device)
    
    # 定义多个损失函数
    mse_criterion = nn.MSELoss()  # MSE损失(用于PSNR)
    lpips_criterion = lpips.LPIPS(net='alex').to(device)  # LPIPS损失
    
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # 准备数据
    transform = transforms.Compose([
        transforms.ToTensor(),
        # 添加归一化处理
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    img = Image.open(image_path)
    img_tensor = transform(img).unsqueeze(0).to(device)
    
    # 训练模型
    model.train()
    for epoch in range(300):
        optimizer.zero_grad()
        output = model(img_tensor)
        
        # 计算多个损失
        # 1. MSE损失(用于PSNR)
        mse_loss = mse_criterion(output, img_tensor)
        psnr = -10 * torch.log10(mse_loss)
        
        # 2. MS-SSIM损失
        ms_ssim_loss = 1 - ms_ssim(output, img_tensor, data_range=1.0)
        
        # 3. LPIPS损失
        lpips_loss = lpips_criterion(output, img_tensor).mean()
        
        # 组合损失,使用权重平衡各项
        total_loss = (
            1.0 * mse_loss +        # 基础重建损失
            0.3 * ms_ssim_loss +    # 结构相似性损失
            0.1 * lpips_loss        # 感知损失
        )
        
        total_loss.backward()
        optimizer.step()
        
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/300]')
            print(f'PSNR: {psnr.item():.2f}')
            print(f'MS-SSIM Loss: {ms_ssim_loss.item():.4f}')
            print(f'LPIPS: {lpips_loss.item():.4f}')
            print('------------------------')

    # 保存模型
    if model_path:
        torch.save(model.state_dict(), model_path)
    
    # 压缩图像
    model.eval()
    with torch.no_grad():
        img = Image.open(image_path)
        img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)
        compressed = model(img_tensor)
        
        # 将结果转换回图像
        output_img = transforms.ToPILImage()(compressed.squeeze(0).cpu())
        output_img.save(output_path)

def decompress_image(compressed_path, output_path, model_path):
    # 加载模型
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ImageCompressor().to(device)
    
    if not os.path.exists(model_path):
        raise ValueError("未找到模型文件!")
    
    model.load_state_dict(torch.load(model_path))
    model.eval()
    
    # 解压缩图像
    with torch.no_grad():
        img = Image.open(compressed_path)
        img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)
        decompressed = model(img_tensor)
        
        # 将结果转换回图像
        output_img = transforms.ToPILImage()(decompressed.squeeze(0).cpu())
        output_img.save(output_path)

if __name__ == "__main__":
    import os
    
    # 示例使用
    input_path = "input.jpg"
    compressed_path = "compressed.jpg"
    decompressed_path = "decompressed.jpg"
    model_path = "compressor_model.pth"
    
    # 压缩流程
    print("正在压缩图像...")
    compress_image(input_path, compressed_path, model_path)
    print(f"压缩完成,已保存至 {compressed_path}")
    
    # 解压缩流程
    print("正在解压缩图像...")
    decompress_image(compressed_path, decompressed_path, model_path)
    print(f"解压缩完成,已保存至 {decompressed_path}")
目前效果如下:
点击查看代码
Epoch [300/300]
PSNR: 10.01
MS-SSIM Loss: 0.3211
LPIPS: 0.2903
posted @ 2025-02-27 06:03  我是个rapper喔  阅读(45)  评论(0)    收藏  举报