import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
import random
from PIL import Image
from tensorflow.examples.tutorials.mnist import input_data
sess = tf.Session()
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# This example has 10 classes (digits 0-9); labels are one-hot encoded.
train_size = 1000
test_size = 102
# Draw the training sample from mnist.train and the test sample from the
# held-out mnist.test split. Previously both were drawn independently from
# mnist.train, so a test image could also be a training image (its own
# nearest neighbour at distance 0), which inflated the reported accuracy.
rand_train_indices = np.random.choice(len(mnist.train.images), train_size, replace=False)
rand_test_indices = np.random.choice(len(mnist.test.images), test_size, replace=False)
x_vals_train = mnist.train.images[rand_train_indices]
x_vals_test = mnist.test.images[rand_test_indices]
y_vals_train = mnist.train.labels[rand_train_indices]
y_vals_test = mnist.test.labels[rand_test_indices]
# Number of nearest neighbours that vote on each prediction.
k = 4
# Number of test images classified per sess.run call.
batch_size = 6
# Flattened 28x28 images (784 features) and one-hot labels (10 classes).
x_data_train = tf.placeholder(shape=[None, 784], dtype=tf.float32)
x_data_test = tf.placeholder(shape=[None, 784], dtype=tf.float32)
y_target_train = tf.placeholder(shape=[None, 10], dtype=tf.float32)
# Not used by any op below; it is only fed alongside the others.
y_target_test = tf.placeholder(shape=[None, 10], dtype=tf.float32)
# L1 distance from every test image to every training image.
# x_data_test is expanded to (batch, 1, 784) so it broadcasts against
# x_data_train (n_train, 784) to (batch, n_train, 784), e.g. (6, 1000, 784);
# summing |diff| over axis 2 gives (batch, n_train), e.g. (6, 1000).
distance = tf.reduce_sum(
    tf.abs(tf.subtract(x_data_train, tf.expand_dims(x_data_test, 1))), axis=2)
# top_k returns the largest values, so negate the distances to select the
# k nearest training images per test image: indices are (batch, k), e.g. (6, 4).
top_k_xvals, top_k_indices = tf.nn.top_k(tf.negative(distance), k=k)
# One-hot labels of those neighbours: gather((n_train, 10), (batch, k))
# -> (batch, k, 10), e.g. (6, 4, 10).
prediction_indices = tf.gather(y_target_train, top_k_indices)
# Vote counts per class, summed over the k neighbours: (batch, 10).
count_of_prediction = tf.reduce_sum(prediction_indices, axis=1)
# Predicted class = the class with the most votes: (batch,).
prediction = tf.argmax(count_of_prediction, axis=1)
# Number of batches needed to cover the whole test set; the last batch may be
# partial. Integer ceil-division avoids depending on true-division semantics.
num_loop = -(-len(x_vals_test) // batch_size)
test_output = []   # predicted class index for every test image
actual_vals = []   # true class index for every test image
for i in range(num_loop):
    min_index = i * batch_size
    max_index = min((i + 1) * batch_size, len(x_vals_test))
    # Slice out the current batch of test images and labels.
    x_batch = x_vals_test[min_index:max_index]
    y_batch = y_vals_test[min_index:max_index]
    # The full training sample is fed every time: each batch is compared
    # against all training images.
    predictions = sess.run(prediction, feed_dict={x_data_test: x_batch,
                                                  x_data_train: x_vals_train,
                                                  y_target_test: y_batch,
                                                  y_target_train: y_vals_train})
    test_output.extend(predictions)
    # Convert one-hot labels to class indices for comparison.
    actual_vals.extend(np.argmax(y_batch, axis=1))
# Accuracy: fraction of test images whose predicted class equals the true class.
# Averaged over the predictions actually collected (not test_size), without
# summing many 1/N floats one at a time.
accuracy = np.mean(np.asarray(test_output) == np.asarray(actual_vals))
print("Accuracy: " + str(accuracy))
# Display the images of the last test batch with their true and predicted labels.
actuals = np.argmax(y_batch, axis=1)
# 3 columns; at least 2 rows (the original 2x3 layout for batch_size=6), and
# more rows when a larger batch_size needs them instead of overflowing the grid.
ncols = 3
nrows = max(2, -(-len(actuals) // ncols))
for i, (actual, predicted) in enumerate(zip(actuals, predictions)):
    ax = plt.subplot(nrows, ncols, i + 1)
    # Each image is stored flattened; restore it to 28x28 pixels for display.
    ax.imshow(np.reshape(x_batch[i], [28, 28]), cmap="Greys_r")
    ax.set_title('Actual: ' + str(actual) + ' Pred:' + str(predicted), fontsize=10)
    # Hide tick marks and labels on both axes.
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()