import tensorflow as tf
import os
'''
tensorflow读取文件的流程,每一步每一种数据都有对应封装好的API进行处理:
1、构造一个文件队列:A文件,B文件,C文件,每个文件内有100个样本
2、读取队列内容:一个一个样本读取,二进制文件就是指定一个样本的byte读取,图片就是一张一张
3、进行解码
4、批处理:将样本一个一个的放入一个队列中,达到一定数量后,一次性进行处理
'''
def readcsv(filelist):
"""读取csv文件"""
# 1.构造文件队列
file_queue = tf.train.string_input_producer(filelist)
# 2.构造CSV阅读器读取队列数据(按一行),read返回一个元组,一个是路径一个是样本内容
reader = tf.TextLineReader()
key, value = reader.read(file_queue)
# 对每行内容进行解码,field_delim分隔符默认“,”,record_defaults指定每一个样本的每一列类型,并设置默认值对缺失值进行填充
# CSV数据中有几列就应该有几个列表,"None"表示字符串并同时指定默认值是None,1表示int类型,并同时指定默认值是1,如果是4.5则是float类型,默认值就是4.5
# decode_csv返回的是每一个样本每一个的值,返回的是op列表
records = [["None"],[1]]
example, label = tf.decode_csv(value, field_delim=',', record_defaults=records)
# 批处理,batch_size从队列读取的批处理大小,num_threads使用几个线程处理,capacity批处理队列大小,tf.train.batch返回的是两个元素的op,一个op存储着一列九行数据
example_batch, label_batch = tf.train.batch([example, label], batch_size=5, num_threads=1, capacity=10)
return example_batch, label_batch
if __name__ == "__main__":
# 构造文件列表
file_name = os.listdir("./data/csvdata")
filelist = [os.path.join("./data/csvdata", file) for file in file_name]
example_batch, label_batch = readcsv(filelist)
# 开启会话
with tf.Session() as sess:
# 定义线程协调器
coord = tf.train.Coordinator()
# 开启读取文件的线程
thd = tf.train.start_queue_runners(sess, coord=coord, start=True)
# 打印读取内容
print(sess.run([example_batch, label_batch]))
# 回收子线程
coord.request_stop()
coord.join(thd)