# encoding: UTF-8
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data as mnist_data
import tensorflow as tf
from tensorflow.python.platform import gfile
import os
# Report which TensorFlow build this script runs against and where it lives.
print("Tensorflow version %s" % (tf.__version__,))
print(tf.__path__)
# Legacy MNIST softmax-regression example (disabled; kept for reference).
# tf.set_random_seed(0)
# # Load the MNIST data
# mnist = mnist_data.read_data_sets("data", one_hot=True)
# # Input placeholders
# x = tf.placeholder("float", [None, 784])
# y_ = tf.placeholder("float", [None,10])
# # Weights and biases
# W = tf.Variable(tf.zeros([784,10]))
# b = tf.Variable(tf.zeros([10]))
# # Network output
# y = tf.nn.softmax(tf.matmul(x,W) + b)
# # Cross-entropy loss
# cross_entropy = -tf.reduce_sum(y_*tf.log(y))
# # Training op
# learningRate = 0.005
# train_step = tf.train.GradientDescentOptimizer(learningRate).minimize(cross_entropy)
# init = tf.initialize_all_variables()
# sess = tf.Session()
# sess.run(init)
# itnum = 1000;
# batch_size = 100;
# for i in range(itnum):
#     if i % 100 == 0:
#         print("the index " + str(i + 1) + " train")
#     batch_xs, batch_ys = mnist.train.next_batch(batch_size)
#     sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
# correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
# accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
def train(export_dir="log", step=200):
    """Run one 3x3 convolution on deterministic ramp data and export the graph.

    Builds an NCHW input ramp and an OIHW weight ramp with NumPy, converts
    them to the NHWC / HWIO layout ``tf.nn.conv2d`` expects, evaluates the
    convolution once and prints the inputs and the result (transposed back to
    NCHW). The graph and variables are then saved as a checkpoint in
    ``export_dir`` and the graph is written there for TensorBoard. Despite the
    name, nothing is trained; the name is kept for existing callers.

    Args:
        export_dir: Directory for the checkpoint and summary files. It is
            created if missing. Defaults to ``"log"``.
        step: Global step appended to the checkpoint prefix
            (``model.ckpt-<step>``). Defaults to 200.
    """
    height = 28
    width = 28
    inchannel = 1
    outchannel = 2
    wkernel = 3
    stride = 1
    # padding='SAME' with a 3x3 kernel and stride 1 is equivalent to pad=1.

    # Deterministic ramps: weights in OIHW layout, input in NCHW layout.
    w = np.arange(wkernel * wkernel * inchannel * outchannel).reshape(
        (outchannel, inchannel, wkernel, wkernel))
    data = np.arange(height * width * inchannel).reshape(
        (1, inchannel, height, width))
    print('input:', data)
    print('weight:', w)

    # tf.nn.conv2d (default data_format) wants NHWC input and HWIO filters.
    data_nhwc = data.transpose(0, 2, 3, 1).astype(np.float32)
    w_hwio = w.transpose(2, 3, 1, 0).astype(np.float32)

    input_var = tf.Variable(data_nhwc, dtype=tf.float32, name="input")
    filter_var = tf.Variable(w_hwio, dtype=tf.float32, name="weight")
    conv = tf.nn.conv2d(input_var, filter_var,
                        strides=[1, stride, stride, 1],
                        padding='SAME', name="conv")
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        # NHWC -> NCHW so the result lines up with the printed input/weight.
        conv_reshape = sess.run(conv).transpose(0, 3, 1, 2)
        print("conv_reshape: \n", conv_reshape)

        saver = tf.train.Saver()
        # Saver.save fails with "Parent directory ... doesn't exist" otherwise.
        tf.gfile.MakeDirs(export_dir)
        checkpoint_file = os.path.join(export_dir, 'model.ckpt')
        # save() returns the real prefix (model.ckpt-<step>); derive the .meta
        # path from it instead of hard-coding the step number.
        saved_prefix = saver.save(sess, checkpoint_file, global_step=step)

        # Verify the exported meta graph re-imports, using a scratch graph so
        # the session's graph is not polluted with a duplicate copy of it.
        with tf.Graph().as_default():
            tf.train.import_meta_graph(saved_prefix + '.meta')

        summary_writer = tf.summary.FileWriter(export_dir, sess.graph)
        # FileWriter buffers events; close() flushes the graph event to disk.
        summary_writer.close()
if __name__ == "__main__":
    # Script entry point: run the conv/export demo only when executed directly.
    train()