如何用加载本地数据库为tf.data.Dataset格式

云端的数据库存储在google的服务器,所以无法通过tfds.load('mnist', split='train')这样的方式直接从云端读取,而且tfds.load('mnist',split='train',data_dir='...')也无法实现本地加载。

下面简单描述如何从本地加载数据

 

一、MNIST数据库

 1. 通过gzip模块打开本地的train-images-idx3-ubyte.gz文件为numpy数据

 2. 通过tf.data.Dataset.from_tensor_slices读取numpy数据

 3.将image数据和label数据打包到一起,并分别打包的数据转换成字典

代码如下:(tf版本为2.12.0)

import gzip
import numpy as np
import tensorflow as tf

# 解压并读取图像数据为 numpy array
# 注意:根据你的idx3-ubyte文件的实际数据维度reshape,如下是28x28图像的例子
with gzip.open('train-images-idx3-ubyte.gz', 'rb') as f:
    images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28 * 28)

# 解压并读取标签数据为 numpy array
with gzip.open('train-labels-idx1-ubyte.gz', 'rb') as f:
    labels = np.frombuffer(f.read(), np.uint8, offset=8)

# 转换 numpy array 到TF数据集对象
dataset_images = tf.data.Dataset.from_tensor_slices(images)
dataset_labels = tf.data.Dataset.from_tensor_slices(labels)

# 组合图像和标签数据集到一起并变成字典形式
dataset = tf.data.Dataset.zip((dataset_images, dataset_labels))
dataset = dataset.map(lambda x, y: {"image": x, "label": y})

# 现在,你可以像之前示例代码一样处理dataset数据集:
for example in dataset.take(1):
    image, label = example['image'], example['label']

 

posted on 2023-10-07 21:42  博闻强记2010  阅读(55)  评论(0编辑  收藏  举报

导航