2021寒假(26)
TensorFlow K近邻算法
实验原理
knn的基本原理:
KNN是通过计算不同特征值之间的距离进行分类。
整体的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。K通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在分类决策上只依据最邻近的一个或者几个样本的类别来决定待分类样本所属的类别。
KNN算法要解决的核心问题是K值选择,它会直接影响分类结果。如果选择较大的K值,就相当于用较大领域中的训练实例进行预测,其优点是可以减少学习的估计误差,但缺点是学习的近似误差会增大。如果选择较小的K值,就相当于用较小的领域中的训练实例进行预测,“学习”近似误差会减小,只有与输入实例较近或相似的训练实例才会对预测结果起作用,与此同时带来的问题是“学习”的估计误差会增大,换句话说,K值的减小就意味着整体模型变得复杂,容易发生过拟合;
使用tensorflow进行KNN算法的整体过程是先设计计算图,然后运行会话,执行计算图的过程,整个过程的数据可见性比较差。以上精确度的计算以及真实标签和预测标签的比较结果其实使用numpy和python的变量。
完整代码
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#导入实验所需的数据
mnist = input_data.read_data_sets("E:/PycharmProjects/TensorFlow/基础/1.3/data",one_hot = True)
#设置训练参数
learning_rate=0.01
training_epochs=25
batch_size=100
display_step=1
#构造计算图,使用占位符placeholder函数构造变量x,y,
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])
#使用Variable函数,设置模型的初始权重
W=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))
#构造逻辑回归模型
pred=tf.nn.softmax(tf.matmul(x,W)+b)
#构造代价函数cost
cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
#使用梯度下降法求最小值,即最优解
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
#初始化全部变量
init=tf.global_variables_initializer()
#.使用tf.Session()创建Session会话对象,会话封装了Tensorflow运行时的状态和控制
with tf.Session() as sess:
sess.run(init)
#调用会话对象sess的run方法,运行计算图,即开始训练模型
for epoch in range(training_epochs):
avg_cost = 0
total_batch = int(mnist.train.num_examples / batch_size)
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
_, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})
avg_cost += c / total_batch
if (epoch+1) % display_step == 0:
print("Epoch:", '%04d' % (epoch + 1), "Cost:","{:.09f}".format(avg_cost))
print("Optimization Finished!")
#测试模型
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
#评估模型的准确度
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Accuracy:", accuracy.eval({x: mnist.test.images[:3000], y: mnist.test.labels[:3000]}))


浙公网安备 33010602011771号