学习九

实现文件的读取

import tensorflow as tf
import os

#批处理大小和队列,数据的数量没有影响,大小之决定这批次取多少数据
def csvread(filelist):
    #1构造文件队列
    file_queue=tf.train.string_input_producer(filelist)
    #2构造阅读器,读取队列数据(读取一行)
    reader=tf.TextLineReader()
    key, value=reader.read(file_queue)
    #对每行内容进行解码
    #record_defaults指定每一个样本的每一列类型,指定默认值
    records=[["None"],["None"]]
    example,label=tf.decode_csv(value,record_defaults=records)
    print(example,label)
    #4想要读取多个数据,就需要批处理batch_size最终决定取多少数据
    example_batch,label_batch=tf.train.batch([example,label],batch_size=9,num_threads=1,capacity=9)
    return example_batch,label_batch


if __name__=="__main__":
    #找到文件,放入列表 路径+名字放入列表中
    file_name=os.listdir("./data/csvdata")
    filelist=[os.path.join("./data/csvdata",file)for file in file_name]
    print(file_name)
    example_batch,label_batch=csvread(filelist)
    #开启会话运行结果
    with tf.Session() as sess:
        #定义线程协调器
        coord=tf.train.Coordinator()
        #开启读取文件的线程
        threads=tf.train.start_queue_runners(sess,coord=coord)
        #打印读取的内容
        print(sess.run([example_batch,label_batch]))
        #回收子线程
        coord.request_stop()
        coord.join(threads)

 

posted on 2020-04-19 14:49  啥123  阅读(128)  评论(0编辑  收藏  举报