数据增强---Mixup
Mixup
什么是Mixup

假设我们在做猫狗分类的任务,label使用one-hot vector形式([1,0] → 狗,[0,1]→猫),mixup的输出将图像和label分别进行了加权融合。
如果用数学公式来表达的话:
\[x=\lambda x_i+(1-\lambda)x_j
\\
y=\lambda y_i+(1-\lambda)y_j \]
其中\(x,y\)是输出图像和label,\((x_i,y_i),(x_j,y_i)\)是输入的两种图像和label,\(\lambda\)一般是一个服从\(\beta\)分布的随机数。
这提供了处于两个类别(猫和狗)之前的一些样本,扩展了样本分布,让训练出的模型具有更强的健壮性。
pytorch实现
 
 
 
import cv2
import numpy as np
import torch
def mix_up(img1, label1, img2, label2, alpha=0.2):
    img1 = img1.astype("float")
    img2 = img2.astype("float")
    alpha = 0.2
    lambda_ = np.random.beta(alpha, alpha)
    mixed_image = lambda_ * img1 + (1. - lambda_) * img2
    mixed_label = np.multiply(lambda_, label1) + np.multiply((1. - lambda_), label2)
    return mixed_image, mixed_label
if __name__ == '__main__':
    dog = cv2.imread("data/inu.png")
    dog_label = [0., 1.]
    cat = cv2.imread("data/neko.png")
    cat_label = [1., 0.]
    mixed_img, mixed_label = mix_up(dog, dog_label, cat, cat_label)
    cv2.imshow('original', dog)
    cv2.waitKey()
    cv2.imshow('original', cat)
    cv2.waitKey()
    mixed_img = mixed_img.astype("uint8")
    cv2.imshow('original', mixed_img)
    cv2.waitKey()
    cv2.imwrite("data/mixed.png", mixed_img)
参考文章
https://towardsdatascience.com/enhancing-neural-networks-with-mixup-in-pytorch-5129d261bc4a
 
                    
                     
                    
                 
                    
                
 
 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号