DCGAN
# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
import numpy as np
import torch.nn.init as init
import os
import test
from GAN_model import Generator, Discriminator

print("data loading ...")

G_LR = 0.0002
D_LR = 0.0002
BATCHSIZE = 50
EPOCHES = 3000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_pt = "data51200.pt"  # path of the training data (produced by data_loader.py or data_loader_sketch.py)
train_para_save_path = "./pkl/"
loss_save_file = 'loss.txt'


def init_ws_bs(m):
    # initialize weights and biases of the transposed-conv layers
    if isinstance(m, nn.ConvTranspose2d):
        init.normal_(m.weight.data, std=0.2)
        if m.bias is not None:  # the layers are built with bias=False
            init.normal_(m.bias.data, std=0.2)


g = Generator().to(device)
d = Discriminator().to(device)
g.apply(init_ws_bs)  # apply the initializer to every submodule
d.apply(init_ws_bs)

### load a previously trained model
# para_path = "./pkl/"
# g = torch.load(para_path + "29g.pkl")
# d = torch.load(para_path + "29d.pkl")

g_optimizer = torch.optim.Adam(g.parameters(), betas=(0.5, 0.999), lr=G_LR)
d_optimizer = torch.optim.Adam(d.parameters(), betas=(0.5, 0.999), lr=D_LR)
g_loss_func = nn.BCELoss()
d_loss_func = nn.BCELoss()
label_real = torch.ones(BATCHSIZE).to(device)
label_fake = torch.zeros(BATCHSIZE).to(device)

if os.path.exists(loss_save_file):
    os.remove(loss_save_file)

if os.path.exists(data_pt):
    real_img = torch.load(data_pt)
    if real_img is not None:
        print("load data successfully")
    else:
        print("fail to load data")

if not os.path.exists(train_para_save_path):
    os.makedirs(train_para_save_path)
for file in os.listdir(train_para_save_path):
    os.remove(train_para_save_path + file)

print("start training")
batch_imgs = []
for epoch in range(EPOCHES):
    np.random.shuffle(real_img)
    loss_epoch = []
    for i in range(len(real_img)):
        batch_imgs.append(real_img[i].numpy())
        if (i + 1) % BATCHSIZE == 0:
            batch_real = torch.Tensor(batch_imgs).to(device)
            batch_imgs.clear()

            #### minimize the discriminator loss
            d_optimizer.zero_grad()
            pre_real = d(batch_real).squeeze()
            d_real_loss = d_loss_func(pre_real, label_real)
            d_real_loss.backward()
            batch_fake = torch.randn(BATCHSIZE, 100, 1, 1).to(device)
            img_fake = g(batch_fake)
            pre_fake = d(img_fake.detach()).squeeze()  # detach so D's update does not backprop into G
            d_fake_loss = d_loss_func(pre_fake, label_fake)
            d_fake_loss.backward()
            d_optimizer.step()

            #### minimize the generator loss
            g_optimizer.zero_grad()
            batch_fake = torch.randn(BATCHSIZE, 100, 1, 1).to(device)
            img_fake = g(batch_fake)
            pre_fake = d(img_fake).squeeze()
            g_loss = g_loss_func(pre_fake, label_real)  # G wants D to judge its images as real
            g_loss.backward()
            g_optimizer.step()

            batch_num = (i + 1) // BATCHSIZE
            print("epoch%d batch%d:" % (epoch, batch_num),
                  (d_real_loss + d_fake_loss).detach().cpu().numpy(),
                  g_loss.detach().cpu().numpy())
            loss_epoch.append([(d_real_loss + d_fake_loss).detach().cpu().numpy(),
                               g_loss.detach().cpu().numpy()])

    ### after finishing an epoch, record the checkpoints and losses
    torch.save(g, train_para_save_path + str(epoch) + "g.pkl")
    torch.save(d, train_para_save_path + str(epoch) + "d.pkl")
    with open(loss_save_file, 'a+') as f:
        for d_loss_epoch, g_loss_epoch in loss_epoch:
            f.write(str(d_loss_epoch) + ' ' + str(g_loss_epoch) + '\n')
    test.draw(train_para_save_path + str(epoch) + "g.pkl", str(epoch))

print("finish the training")
GAN_model.py
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.deconv1 = nn.Sequential(  # input: (batchsize, 100, 1, 1)
            nn.ConvTranspose2d(  # output size = stride*(input_w-1) + k - 2*padding
                in_channels=100,
                out_channels=64 * 8,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(inplace=True),
        )  # (batchsize, 512, 4, 4)
        self.deconv2 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=64 * 8,
                out_channels=64 * 4,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(inplace=True),
        )  # (batchsize, 256, 8, 8)
        self.deconv3 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=64 * 4,
                out_channels=64 * 2,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(inplace=True),
        )  # (batchsize, 128, 16, 16)
        self.deconv4 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=64 * 2,
                out_channels=64 * 1,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )  # (batchsize, 64, 32, 32)
        self.deconv5 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
            nn.Tanh(),
        )  # (batchsize, 3, 96, 96)

    def forward(self, x):
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.deconv3(x)
        x = self.deconv4(x)
        x = self.deconv5(x)
        return x


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(  # input: (batchsize, 3, 96, 96)
                in_channels=3,
                out_channels=64,
                kernel_size=5,
                padding=1,
                stride=3,
                bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(.2, inplace=True),
        )  # (batchsize, 64, 32, 32)
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(.2, inplace=True),
        )  # (batchsize, 128, 16, 16)
        self.conv3 = nn.Sequential(
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(.2, inplace=True),
        )  # (batchsize, 256, 8, 8)
        self.conv4 = nn.Sequential(
            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(.2, inplace=True),
        )  # (batchsize, 512, 4, 4)
        self.output = nn.Sequential(
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )  # (batchsize, 1, 1, 1), probability that the image is real

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.output(x)
        return x
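As a quick sanity check (this snippet is illustrative and not part of the original project), the output-size formula noted in the comments, stride*(input_w - 1) + kernel - 2*padding, can be verified by hand for the last generator layer: 3*(32 - 1) + 5 - 2*1 = 96, which matches the 96x96 training images. Pushing random tensors through both networks confirms the shapes:

# Minimal shape check for GAN_model.py (assumed helper script, not from the original repo).
import torch
from GAN_model import Generator, Discriminator

g = Generator()
d = Discriminator()

z = torch.randn(2, 100, 1, 1)   # a batch of 2 latent vectors
img = g(z)                      # expected shape: (2, 3, 96, 96)
score = d(img)                  # expected shape: (2, 1, 1, 1), values in (0, 1) from the Sigmoid
print(img.shape, score.shape)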
The essence of a GAN is the adversarial game. The generator loss and the discriminator loss are backpropagated in exactly the same way; the difference is only which parameters each one updates: the generator loss updates the generator's parameters and the discriminator loss updates the discriminator's parameters (this is decided by the parameters handed to each optimizer).
The generator has a single training objective: make its fake images look real to the discriminator: g_loss = g_loss_func(pre_fake, label_real)
The discriminator has two objectives: judge real images as real: d_real_loss = d_loss_func(pre_real, label_real)
and judge fake images as fake: d_fake_loss = d_loss_func(pre_fake, label_fake)
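For illustration, once training has written a generator checkpoint such as ./pkl/29g.pkl, it can be loaded and sampled roughly as follows. This is a sketch, not the author's test.draw: the epoch number, output file name, and the use of torchvision's save_image are assumptions.

# Rough sampling sketch (assumed helper, not the project's test.draw).
import torch
from torchvision.utils import save_image  # assumes torchvision is installed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# GAN_model must be importable here, because torch.save stored the whole Generator object.
g = torch.load("./pkl/29g.pkl", map_location=device)  # "29" is just an example epoch
g.eval()

with torch.no_grad():
    z = torch.randn(25, 100, 1, 1, device=device)
    fake = g(z)                                   # (25, 3, 96, 96), values in [-1, 1] from Tanh
save_image(fake, "sample_epoch29.png", nrow=5, normalize=True)  # rescale [-1, 1] to [0, 1] before saving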