# 图像风格迁移（PyTorch）

#### 图像风格迁移

##### Content Loss

$$l_{content} = \frac{1}{2}\sum (C_c-T_c)^2$$

##### Style Loss

$$l_{style}=\sum_i w_i\,(T_{s,i}-S_{s,i})^2$$

##### 总的损失函数

$$L_{total}(S,C,T)=\alpha\, l_{content}(C,T)+\beta\, l_{style}(S,T)$$

##### 代码
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms, models

# Use the pretrained VGG19; `.features` keeps only the part of the network
# that excludes the fully connected layers.
vgg = models.vgg19(pretrained=True).features

# Freeze the network weights: only the target image is optimized, so the
# VGG parameters must not collect gradients.
for param in vgg.parameters():
    param.requires_grad_(False)


def load_img(path, max_size=400, shape=None):
    """Load an image file as a normalized tensor with a leading batch axis.

    Args:
        path: Image location, passed to ``PIL.Image.open``.
        max_size: Cap applied to the size value handed to
            ``transforms.Resize``. The value used is the image's longer side
            or ``max_size``, whichever is smaller.
        shape: Optional explicit size for ``transforms.Resize``; when given
            it replaces the size derived from ``max_size``.

    Returns:
        The image as an RGB tensor, normalized with the mean/std constants
        below, with a batch dimension added in front.

    NOTE(review): per the torchvision docs, an int passed to ``Resize``
    scales the *shorter* edge to that value, so the longer edge can still
    exceed ``max_size`` for non-square images. Behavior kept as-is.
    """
    img = Image.open(path).convert('RGB')

    # Limit the resize target to max_size.
    size = min(max(img.size), max_size)
    if shape is not None:
        size = shape

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])

    # convert('RGB') has already dropped any alpha channel, so no channel
    # slicing is needed; unsqueeze(0) adds the batch dimension.
    return transform(img).unsqueeze(0)


content = load_img('./images/turtle.jpg')

# NOTE(review): `style` is used further down but was never loaded in the
# original listing. The path below is a placeholder -- TODO point it at the
# actual style image. It is resized to the content image's height/width.
style = load_img('./images/style.jpg', shape=tuple(content.shape[-2:]))

def im_convert(tensor):
    """Convert a normalized image tensor into an array matplotlib can draw.

    Args:
        tensor: Image tensor in channel-first layout, optionally with
            singleton dimensions (e.g. a batch axis of size 1), normalized
            with the mean/std constants used below.

    Returns:
        A NumPy array in height x width x channel layout with the
        normalization undone and values clipped to [0, 1].
    """
    # detach() drops autograd tracking; cpu() makes .numpy() valid even when
    # the tensor lives on another device.
    img = tensor.detach().cpu().numpy().squeeze()
    # Channel-first -> channel-last, as expected by plt.imshow.
    img = img.transpose(1, 2, 0)
    # Undo Normalize: x * std + mean.
    img = img * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    return img.clip(0, 1)


def get_features(img, model, layers=None):
    """Run ``img`` through ``model`` and collect selected layer outputs.

    Args:
        img: Input tensor fed to the first sub-module of ``model``.
        model: Module whose direct sub-modules are applied one after another
            in registration order.
        layers: Mapping from sub-module name (as registered on ``model``) to
            the key under which its output is stored. Defaults to the layer
            selection below, where ``conv4_2`` is the content layer.

    Returns:
        Dict mapping each requested key to the output of that sub-module.
    """
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',    # content layer
            '28': 'conv5_1',
        }

    features = {}
    x = img
    # Iterate the registered sub-modules directly (rather than
    # named_children()) so a module instance registered twice is still
    # applied twice.
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x

    return features

def gram_matrix(tensor):
    """Compute the Gram matrix of a feature map.

    Args:
        tensor: 4-D tensor; the first dimension is the batch size and is
            expected to be 1, followed by depth, height and width.

    Returns:
        A (depth, depth) tensor: the product of the flattened feature map
        with its own transpose.
    """
    _, d, h, w = tensor.size()  # first entry is the batch size

    # Flatten the spatial dimensions so each row is one channel. reshape
    # (unlike view) also accepts non-contiguous inputs.
    flat = tensor.reshape(d, h * w)

    return torch.mm(flat, flat.t())

# The content and style features are fixed targets of the optimization, so
# compute them without autograd tracking. Otherwise they may keep a graph
# that is freed by the first backward() call and breaks the next one.
with torch.no_grad():
    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)

    # One Gram matrix per extracted layer of the style image.
    style_grams = {
        name: gram_matrix(feature) for name, feature in style_features.items()
    }

# Per-layer weights for the style loss.
style_weights = {
    'conv1_1': 1,
    'conv2_1': 0.8,
    'conv3_1': 0.5,
    'conv4_1': 0.3,
    'conv5_1': 0.1,
}

# Weights of the two loss terms in the total loss.
content_weight = 1
style_weight = 1e6


show_every = 400  # print the loss and show the image every this many steps
steps = 2000      # total number of optimization steps

# NOTE(review): the original listing never defines `target` or `optimizer`.
# Starting from a copy of the content image and using Adam is an assumption;
# the learning rate is a starting value -- TODO confirm / tune.
target = content.clone().requires_grad_(True)
optimizer = optim.Adam([target], lr=0.003)

for ii in range(steps):
    target_features = get_features(target, vgg)

    content_loss = torch.mean(
        (target_features['conv4_2'] - content_features['conv4_2']) ** 2
    )

    # Accumulate the Gram-matrix loss of every weighted style layer.
    style_loss = 0
    for layer in style_weights:
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        _, d, h, w = target_feature.shape
        style_gram = style_grams[layer]
        layer_style_loss = style_weights[layer] * torch.mean(
            (target_gram - style_gram) ** 2
        )
        # Add to the total style loss, divided by the feature map size.
        style_loss += layer_style_loss / (d * h * w)

    total_loss = content_weight * content_loss + style_weight * style_loss

    # Clear gradients left over from the previous step; without this they
    # accumulate across iterations.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if ii % show_every == 0:
        print('Total Loss:', total_loss.item())
        plt.imshow(im_convert(target))
        plt.show()


##### 参考：
1. 论文：Gatys, Ecker, Bethge, "Image Style Transfer Using Convolutional Neural Networks", CVPR 2016
2. Udacity——PyTorch Scholarship Challenge
posted @ 2019-03-21 16:24  MartinLwx  阅读(4757)  评论(2)  编辑  收藏  举报