(13)tensorflow数据集操作

经典数据集操作

功能函数代码
加载数据集datasets.Dataset_name.load_data()
构建 Dataset 对象tf.data.Dataset_name.from_tensor_slices((x, y))
随机打散Dataset_name.shuffle(buffer_size)
批训练Dataset_name.batch(size)
数据预处理Dataset_name.map(func_name)
数据集Datatset_name类型
Boston housing波士顿房价趋势
CIFAR10/100图片数据集
MNIST/Fashion_MNIST手写数字
IMDB文本分类

数据集缓存在用户目录下的.keras/datasets 文件夹

加载数据集

数据集缓存在用户目录下的.keras/datasets 文件夹(有则加载,无则自动下载)

import tensorflow as tf
from tensorflow.keras import datasets
(x,y),(x_text,y_text) =  datasets.mnist.load_data()
print(x.shape)
print(y.shape)
print(x_text.shape)
print(y_text.shape)

out:
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)

数据加载进入内存后,需要转换成 Dataset 对象, 才能利用 TensorFlow 提供的各种操作

import tensorflow as tf
from tensorflow.keras import datasets
(x,y),(x_text,y_text) =  datasets.mnist.load_data()
print(x.shape)
print(y.shape)
print(x_text.shape)
print(y_text.shape)
train_db = tf.data.Dataset.from_tensor_slices((x, y))
print(train_db)

out:
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)
<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.uint8, tf.uint8)>

随机打散

  • Dataset_name.shuffle(buffer_size)
  • buffer_size为缓冲池大小,设置一个较大常数
import tensorflow as tf
from tensorflow.keras import datasets
(x,y),(x_text,y_text) =  datasets.mnist.load_data()
print(x.shape)
print(y.shape)
print(x_text.shape)
print(y_text.shape)
train_db = tf.data.Dataset.from_tensor_slices((x, y))
td = train_db.shuffle(500)
print(td)

out:
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)
<ShuffleDataset shapes: ((28, 28), ()), types: (tf.uint8, tf.uint8)>

批训练

  • Dataset_name.batch(size)
  • 同时并行计算多个样本为批训练,size即为并行计算数目,尽量根据显卡性能配置
import tensorflow as tf
from tensorflow.keras import datasets
(x,y),(x_text,y_text) =  datasets.mnist.load_data()
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.batch(100)
print(train_db)

out:
<BatchDataset shapes: ((None, 28, 28), (None,)), types: (tf.uint8, tf.uint8)>

预处理

  • Dataset_name.map(func_name)
import tensorflow as tf
from tensorflow.keras import datasets
(x,y),(x_text,y_text) =  datasets.mnist.load_data()
train_db = tf.data.Dataset.from_tensor_slices((x, y))
def func_name(x,y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [-1, 28 * 28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x , y
train_db = train_db.map(func_name)
print(train_db)

out:
<MapDataset shapes: ((1, 784), (10,)), types: (tf.float32, tf.float32)>

循环训练

  •   for step, (x,y) in enumerate(train_db):
    
  •   for x,y in train_db:
    
  •   for epoch in range(20):
      	for step, (x,y) in enumerate(train_db):
    
  •   train_db = train_db.repeat(20)
    
posted @ 2020-09-04 09:16  kuanleung  阅读(48)  评论(0)    收藏  举报  来源