import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
import tensorflow as tf
sess=tf.Session()
iris=datasets.load_iris()
#print(iris)
target=np.array([1. if x==0 else 0. for x in iris.target ])
#print(target.shape)
iris_data=np.array([ [x[2],x[3]] for x in iris.data] ) #shape[none,2]
#声明批量
batch_size=20
#宽度长度 均为【NOne,1】
x1_data=tf.placeholder(tf.float32,shape=[None,1])
x2_data=tf.placeholder(tf.float32,shape=[None,1])
y_target=tf.placeholder(tf.float32,shape=[None,1])
#初始化类型为1,1 可以和x1 相乘
A=tf.Variable(tf.random_normal(shape=[1,1]))
b=tf.Variable(tf.random_normal(shape=[1,1]))
#线性模型 x1=x2*A+b --》 f=x1-x2*A-b
my_mult=tf.matmul(x2_data,A)
my_add=tf.add(my_mult,b)
my_output=tf.subtract(x1_data,my_add)
#损失函数 (交叉熵损失函数 非归一化 常用于两类验证)
sigmoid_logits=tf.nn.sigmoid_cross_entropy_with_logits(labels=y_target,logits=my_output)
#梯度下降取最小值 (选择学习率0.05)
my_opt=tf.train.GradientDescentOptimizer(0.05)
train_step=my_opt.minimize(sigmoid_logits)
#初始化所有声明的变量
init=tf.global_variables_initializer()
sess.run(init)
#迭代100次 训练模型 传入三种数据 长度 宽度 和目标
for i in range(1500):
#随机获取批量数据 根据(iris_data)的长度已经确定
rand_index=np.random.choice(len(iris_data) ,batch_size)
#shape=[batchsize,1]
x1_rand= np.array([[iris_data[x][0]] for x in rand_index],dtype=np.float32)
x2_rand = np.array([[iris_data[x][1]] for x in rand_index],dtype=np.float32)
y_rand=np.array([[target[x]] for x in rand_index],dtype=np.float32)
sess.run(train_step,feed_dict={x1_data:x1_rand,x2_data:x2_rand,y_target:y_rand})
if (i+1)%200 ==0:
print('Step %s :A= %s ; b=%s ' % ( i+1,str(sess.run(A)), str(sess.run(b)) ))
#保存A,b
[[slope]]=sess.run(A)
[[intercept]]=sess.run(b)
x=np.linspace(0,3,num=50)
abline=[]
for i in x:
abline.append(slope*i+intercept)
#重新选取数据 从目标1中选取 长度 宽度
set1_x=[ a[1] for i,a in enumerate(iris_data) if target[i] == 1]
set1_y=[ a[0] for i,a in enumerate(iris_data) if target[i] == 1]
#重新选取数据 从目标0中 选取长度宽度
no_set1_x=[ a[1] for i,a in enumerate(iris_data) if target[i] == 0]
no_set1_y=[ a[0] for i,a in enumerate(iris_data) if target[i] == 0]
plt.plot(set1_x,set1_y,'rx',ms=10,mew=2,label='set1')
#plt.clabel('set1')
plt.plot(no_set1_x,no_set1_y,'ro',label='set0')
#plt.clabel('set0')
plt.plot(x,abline,'b-',label='my')
plt.xlim([0.0,2.7])
plt.ylim([0.0,7.1])
plt.xlabel('length')
plt.ylabel('width')
plt.legend(loc='lower right')
plt.show()