keras 学习笔记(二) ——— data_generator
data_generator
每次输出一个batch,基于keras.utils.Sequence
Base object for fitting to a sequence of data, such as a dataset.
Every Sequence must implement the __getitem__ and the __len__ methods. If you want to modify your dataset between epochs you may implement on_epoch_end. The method __getitem__ should return a complete batch.
Notes
Sequence are a safer way to do multiprocessing. This structure guarantees that the network will only train once on each sample per epoch which is not the case with generators.
Sequence example: https://keras.io/utils/#sequence
#!/usr/bin/env python
# coding: utf-8
from keras.utils import Sequence
import numpy as np
from keras.preprocessing import image
from skimage.io import imread
class My_Custom_Generator(Sequence):
    """Keras ``Sequence`` that loads one batch of image sequences per index.

    ``image_filenames`` is sliced per batch, and every element of the slice
    is iterated as a collection of image paths (one sequence of frames).
    ``labels`` is sliced with the same bounds, so it must be aligned with
    ``image_filenames``.

    NOTE(review): all images are presumably the same size and all sequences
    the same length, otherwise ``np.array`` cannot build a regular batch
    array -- confirm against the dataset.
    """

    def __init__(self, image_filenames, labels, batch_size):
        self.image_filenames = image_filenames
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches per epoch (the last may be partial)."""
        # Convert with the built-in int: the ``np.int`` alias used previously
        # was deprecated in NumPy 1.20 and removed in NumPy 1.24.
        return int(np.ceil(len(self.image_filenames) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Load and return batch number ``idx`` as ``(images, labels)`` arrays."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_x = self.image_filenames[start:stop]
        batch_y = self.labels[start:stop]

        # One inner list of image arrays per sequence; the nested list is
        # converted to an ndarray once, at the end, instead of per sequence.
        batch_seq = [
            [image.img_to_array(imread(path)) for path in seq_paths]
            for seq_paths in batch_x
        ]
        return np.array(batch_seq), np.array(batch_y)
两种将数据输出为numpy.array的方法
通过list转为numpy.array
速度快,list转array过程需要注意数据维度变化
''' list
# Filename sequences that belong to batch number ``idx``.
batch_x = X_train_filenames[idx * batch_size : (idx+1) * batch_size]
batch_seq = []  # one list of image arrays per sequence
for x in batch_x:  # original note: len(x) = 16
    seq_img = []
    for img in x:  # original note: len(item) = 25
        seq_img.append(image.img_to_array(imread(img)))
    # The per-sequence ``np.array([seq_img])`` copy was never used and has
    # been dropped; the nested list is converted once, below.
    batch_seq.append(seq_img)
# list -> ndarray: check the resulting shape, every nesting level becomes
# one axis of the array.
batch_seq_list = np.array(batch_seq)
'''
利用np.empty
速度慢,开始前确定batch维度即可
'''numpy
# Filename sequences that belong to batch number ``idx``.
batch_x = X_train_filenames[idx * batch_size : (idx+1) * batch_size]
# Zero-length leading axis so that np.append can grow the batch along axis 0.
batch_seq = np.empty((0, 25, 224, 224, 3), float)
for x in batch_x:  # original note: len(x) = 16
    seq_batch = np.empty((0, 224, 224, 3), float)
    for item in x:  # original note: len(item) = 25
        frame = image.img_to_array(imread(item))
        seq_batch = np.append(seq_batch, np.expand_dims(frame, axis=0), axis=0)
    # np.append returns a new array, so the result must be bound back to
    # ``batch_seq``; binding it to another name (previously ``batch_seq2``)
    # left ``batch_seq`` empty and kept only the last sequence.
    batch_seq = np.append(batch_seq, np.expand_dims(seq_batch, axis=0), axis=0)
'''

 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号