from __future__ import print_function
import os
import matplotlib as mpl
import tarfile
import matplotlib.image as mpimg
from matplotlib import pyplot as plt
import mxnet as mx
from mxnet import gluon
from mxnet import ndarray as nd
from mxnet.gluon import nn, utils
from mxnet import autograd
import numpy as np
epochs = 2 # Set low by default for tests, set higher when you actually run this code.
batch_size = 64
latent_z_size = 100
use_gpu = False
ctx = mx.gpu() if use_gpu else mx.cpu()
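# Optional sketch (assumes a recent MXNet 1.x where mx.context.num_gpus() exists):
# pick the context automatically instead of toggling use_gpu by hand.
# ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()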
lr = 0.0002
beta1 = 0.5
data_path = 'lfw_dataset'
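# The archive is assumed to already sit next to this script. If it does not, a minimal
# sketch for fetching it first (the URL is an assumption; the LFW hosting may change):
if not os.path.exists('lfw-deepfunneled.tgz'):
    utils.download('http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz',
                   path='lfw-deepfunneled.tgz')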
with tarfile.open("lfw-deepfunneled.tgz") as tar:
    tar.extractall(path=data_path)
target_wd = 64
target_ht = 64
img_list = []
def transform(data, target_wd, target_ht):
    # resize to target_wd * target_ht
    data = mx.image.imresize(data, target_wd, target_ht)
    # transpose from (target_wd, target_ht, 3)
    # to (3, target_wd, target_ht)
    data = nd.transpose(data, (2, 0, 1))
    # normalize to [-1, 1]
    data = data.astype(np.float32) / 127.5 - 1
    # if the image is greyscale, repeat 3 times to get an RGB image.
    if data.shape[0] == 1:
        data = nd.tile(data, (3, 1, 1))
    return data.reshape((1,) + data.shape)
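# Quick sanity check (illustrative only; the dummy array is not part of the dataset):
# a 3-channel uint8 image should come out of transform() as a single (1, 3, 64, 64)
# float32 example with values in [-1, 1].
dummy = nd.random_uniform(0, 255, shape=(300, 250, 3)).astype(np.uint8)
print(transform(dummy, target_wd, target_ht).shape)  # expected: (1, 3, 64, 64)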
for path, _, fnames in os.walk(data_path):
    for fname in fnames:
        if not fname.endswith('.jpg'):
            continue
        img = os.path.join(path, fname)
        img_arr = mx.image.imread(img)
        img_arr = transform(img_arr, target_wd, target_ht)
        img_list.append(img_arr)
train_data = mx.io.NDArrayIter(data=nd.concatenate(img_list), batch_size=batch_size)
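# NDArrayIter pads the final incomplete batch by default (last_batch_handle='pad'), so
# every batch the training loop sees has exactly batch_size examples, matching the
# fixed-size real/fake label vectors defined later.
print('number of training images:', len(img_list))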
def visualize(img_arr):
    plt.imshow(((img_arr.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8))
    plt.axis('off')
for i in range(4):
    plt.subplot(1, 4, i + 1)
    visualize(img_list[i + 10][0])
plt.show()
nc = 3
ngf = 64
netG = nn.Sequential()
with netG.name_scope():
    # input is Z, going into a convolution
    netG.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False))
    netG.add(nn.BatchNorm())
    netG.add(nn.Activation('relu'))
    # state size. (ngf*8) x 4 x 4
    netG.add(nn.Conv2DTranspose(ngf * 4, 4, 2, 1, use_bias=False))
    netG.add(nn.BatchNorm())
    netG.add(nn.Activation('relu'))
    # state size. (ngf*4) x 8 x 8
    netG.add(nn.Conv2DTranspose(ngf * 2, 4, 2, 1, use_bias=False))
    netG.add(nn.BatchNorm())
    netG.add(nn.Activation('relu'))
    # state size. (ngf*2) x 16 x 16
    netG.add(nn.Conv2DTranspose(ngf, 4, 2, 1, use_bias=False))
    netG.add(nn.BatchNorm())
    netG.add(nn.Activation('relu'))
    # state size. (ngf) x 32 x 32
    netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
    netG.add(nn.Activation('tanh'))
    # state size. (nc) x 64 x 64
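# How the spatial sizes above fall out (standard transposed-convolution arithmetic,
# with output_padding left at its default of 0):
#   out = (in - 1) * stride - 2 * padding + kernel
# first layer:  (1 - 1) * 1 - 2 * 0 + 4 = 4      -> 4 x 4
# later layers: (4 - 1) * 2 - 2 * 1 + 4 = 8, then 16, 32, 64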
# build the discriminator
ndf = 64
netD = nn.Sequential()
with netD.name_scope():
    # input is (nc) x 64 x 64
    netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))
    netD.add(nn.LeakyReLU(0.2))
    # state size. (ndf) x 32 x 32
    netD.add(nn.Conv2D(ndf * 2, 4, 2, 1, use_bias=False))
    netD.add(nn.BatchNorm())
    netD.add(nn.LeakyReLU(0.2))
    # state size. (ndf*2) x 16 x 16
    netD.add(nn.Conv2D(ndf * 4, 4, 2, 1, use_bias=False))
    netD.add(nn.BatchNorm())
    netD.add(nn.LeakyReLU(0.2))
    # state size. (ndf*4) x 8 x 8
    netD.add(nn.Conv2D(ndf * 8, 4, 2, 1, use_bias=False))
    netD.add(nn.BatchNorm())
    netD.add(nn.LeakyReLU(0.2))
    # state size. (ndf*8) x 4 x 4
    netD.add(nn.Conv2D(1, 4, 1, 0, use_bias=False))
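# Discriminator sizes follow the usual convolution arithmetic:
#   out = floor((in + 2 * padding - kernel) / stride) + 1
# e.g. (64 + 2 - 4) / 2 + 1 = 32, then 16, 8, 4; the final 4x4 / stride-1 / pad-0
# convolution collapses each 4 x 4 map to a single 1 x 1 logit.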
# loss
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
# initialize the generator and the discriminator
netG.initialize(mx.init.Normal(0.02), ctx=ctx)
netD.initialize(mx.init.Normal(0.02), ctx=ctx)
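# A minimal shape check (illustrative; relies on Gluon's deferred initialization being
# resolved on the first forward pass): one latent vector should map to a 64 x 64 RGB
# image, and the discriminator should reduce it to a single logit.
z_check = mx.nd.random_normal(0, 1, shape=(1, latent_z_size, 1, 1), ctx=ctx)
fake_check = netG(z_check)
print(fake_check.shape)        # expected: (1, 3, 64, 64)
print(netD(fake_check).shape)  # expected: (1, 1, 1, 1)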
# trainer for the generator and the discriminator
trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
from datetime import datetime
import time
import logging
real_label = nd.ones((batch_size,), ctx=ctx)
fake_label = nd.zeros((batch_size,),ctx=ctx)
def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()
metric = mx.metric.CustomMetric(facc)
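# Illustrative call (CustomMetric hands facc numpy arrays): predictions are thresholded
# at 0.5 before being compared with the 0/1 labels, so this is a binary accuracy.
# facc(np.array([1.0, 0.0]), np.array([0.9, 0.2]))  -> 1.0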
stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
logging.basicConfig(level=logging.DEBUG)
for epoch in range(epochs):
    tic = time.time()
    btic = time.time()
    train_data.reset()
    iter = 0
    for batch in train_data:
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        data = batch.data[0].as_in_context(ctx)
        latent_z = mx.nd.random_normal(0, 1, shape=(batch_size, latent_z_size, 1, 1), ctx=ctx)

        with autograd.record():
            # train with real images
            output = netD(data).reshape((-1, 1))
            errD_real = loss(output, real_label)
            metric.update([real_label,], [output,])

            # train with fake images
            fake = netG(latent_z)
            output = netD(fake.detach()).reshape((-1, 1))
            errD_fake = loss(output, fake_label)
            errD = errD_real + errD_fake
            errD.backward()
            metric.update([fake_label,], [output,])

        trainerD.step(batch.data[0].shape[0])

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        with autograd.record():
            fake = netG(latent_z)
            output = netD(fake).reshape((-1, 1))
            errG = loss(output, real_label)
            errG.backward()

        trainerG.step(batch.data[0].shape[0])

        # Print log information every ten batches
        if iter % 10 == 0:
            name, acc = metric.get()
            logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
            logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
                         % (nd.mean(errD).asscalar(),
                            nd.mean(errG).asscalar(), acc, iter, epoch))
        iter = iter + 1
        btic = time.time()

    name, acc = metric.get()
    metric.reset()
    # logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
    # logging.info('time: %f' % (time.time() - tic))
    # Visualize one generated image for each epoch
    # fake_img = fake[0]
    # visualize(fake_img)
    # plt.show()
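# A small sketch for keeping the trained weights around after training finishes.
# save_parameters/load_parameters exist in MXNet 1.2.1+; older releases use
# save_params/load_params instead.
netG.save_parameters('netG.params')
netD.save_parameters('netD.params')
# netG.load_parameters('netG.params', ctx=ctx)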
num_image = 12
latent_z = mx.nd.random_normal(0, 1, shape=(1, latent_z_size, 1, 1), ctx=ctx)
step = 0.05
for i in range(num_image):
    img = netG(latent_z)
    plt.subplot(3, 4, i + 1)
    visualize(img[0])
    latent_z += step
plt.show()