数据加载
outline
- keras.datasets
- tf.data.Dataset.from_tensor_slices
- shuffle
- map
- batch
- repeat
获取数据集
In [5]: (x,y), (x_test,y_test)=keras.datasets.mnist.load_data()
In [6]: x.shape #(60000, 28, 28)
In [7]: y.shape #(60000, )
In [8]: x.min(),x.max(),x.mean()
Out[8]: (0, 255, 33.318421449829934)
In [9]: x_test.shape, y_test.shape
Out[9]: ((10000, 28, 28), (10000,))
In [10]: y[:4]
Out[10]: array([5, 0, 4, 1], dtype=uint8)
In [11]: y_onehot=tf.one_hot(y, depth=10)
In [12]: y_onehot[:2]
Out[12]:
<tf.Tensor: id=8, shape=(2, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>
CIFAR10/100
In [14]: (x,y), (x_test,y_test)=keras.datasets.cifar10.load_data()
In [15]: x.shape,y.shape,x_test.shape,y_test.shape
Out[15]: ((50000, 32,32,3),(50000, 1),(10000, 32,32, 3),(10000, 1))
In [16]: x.min(),x.max()
Out[16]: (0, 255)
In [17]: y[:4]
Out[17]:
array([[6],
[9],
[9],
[4]],dtype=uint8)
tf.data.Dataset
In [5]: (x,y),(x_test,y_test)=keras.datasets.cifar10.load_data()
In [6]: db=tf.data.Dataset.from_tensor_slices(x_test)
In [7]: next(iter(db)).shape
Out[7]: TensorShape([32, 32, 3])
# can not use [x_test,y_test]
In [9]: db=tf.data.Dataset.from_tensor_slices((x_test,y_test))
In [11]: next(iter(db))[0].shape
Out[11]: TensorShape([32, 32, 3])
shuffle
In [12]: db=tf.data.Dataset.from_tensor_slices((x_test,y_test))
In [13]: db=db.shuffle(10000)
map
In [16]: def preprocess(x,y):
    ...:     x=tf.cast(x, dtype=tf.float32)/255.
    ...:     y=tf.cast(y, dtype=tf.int32)
    ...:     y=tf.one_hot(y, depth=10)
    ...:     return x,y
In [17]: db2=db.map(preprocess)
In [18]: res=next(iter(db2))
In [19]: res[0].shape, res[1].shape
Out[19]: (TensorShape([32, 32,3]),TensorShape([1, 10]))
In [20]: res[1][:2]
Out[20]: <tf.Tensor: id=58, shape=(1, 10), dtype=float32, numpy=
array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>
batch
- 只需提供batch的大小
In [21]: db3=db2.batch(32)
In [25]: res=next(iter(db3))
In [26]: res[0].shape, res[1].shape
Out[26]: (TensorShape([32, 32, 32, 3]), TensorShape([32, 1, 10]))
StopIteration
In [27]: db_iter = iter(db3)
In [28]: while True:
    ...:     next(db_iter)
OutOfRangeError Traceback (most recent call last)
StopIteration:
repeat
- 控制迭代的次数
In [29]: db4=db3.repeat()
In [30]: db4=db3.repeat(2)
以 fashion_mnist 为例
def prepare_mnist_features_and_labels(x, y):
    """Convert one (image, label) pair to the dtypes used for training.

    Intended to be passed to ``tf.data.Dataset.map``.

    Args:
        x: Image tensor; cast to float32 and divided by 255.0.
        y: Label tensor; cast to int64.

    Returns:
        The ``(x, y)`` pair after conversion.
    """
    # Assign the scaled result back to x: without the assignment the
    # expression is evaluated and thrown away, and the raw input is returned.
    x = tf.cast(x, tf.float32) / 255.0
    y = tf.cast(y, tf.int64)
    return x, y
def mnist_dataset():
    """Build batched training and validation pipelines for Fashion-MNIST.

    Loads the data with ``datasets.fashion_mnist.load_data()``, one-hot
    encodes the labels with depth 10, and wraps each split in a
    ``tf.data.Dataset`` that is mapped through
    ``prepare_mnist_features_and_labels``, shuffled, and batched by 100.

    Returns:
        A ``(ds, ds_val)`` tuple: the training and validation datasets.
    """
    (x, y), (x_val, y_val) = datasets.fashion_mnist.load_data()

    # Labels are one-hot encoded up front, before the per-element map runs.
    y = tf.one_hot(y, depth=10)
    y_val = tf.one_hot(y_val, depth=10)

    # Pass a tuple (not a list) so images and labels are sliced pairwise.
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.map(prepare_mnist_features_and_labels)
    # Shuffle buffer of 60000 is presumably the full training-set size.
    ds = ds.shuffle(60000).batch(100)

    ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    ds_val = ds_val.map(prepare_mnist_features_and_labels)
    # Shuffle buffer of 10000 is presumably the full validation-set size.
    ds_val = ds_val.shuffle(10000).batch(100)
    return ds, ds_val
全连接层