# 观看 TensorFlow 案例实战视频课程 22:迭代及测试网络效果
# Generate one training batch
def get_next_batch(batch_size=128):
    """Build one batch of flattened grayscale captcha images and their labels.

    Reads the module-level IMAGE_HEIGHT, IMAGE_WIDTH, MAX_CAPTCHA and
    CHAR_SET_LEN to size the output arrays.

    Args:
        batch_size: number of captcha samples to generate.

    Returns:
        A tuple ``(batch_x, batch_y)`` of NumPy arrays with shapes
        ``[batch_size, IMAGE_HEIGHT * IMAGE_WIDTH]`` and
        ``[batch_size, MAX_CAPTCHA * CHAR_SET_LEN]``.
    """
    wanted_shape = (60, 160, 3)
    images = np.zeros([batch_size, IMAGE_HEIGHT * IMAGE_WIDTH])
    labels = np.zeros([batch_size, MAX_CAPTCHA * CHAR_SET_LEN])

    for row in range(batch_size):
        # The generator sometimes yields an image that is not (60, 160, 3);
        # keep drawing until one with the expected shape comes back.
        text, image = gen_captcha_text_and_image()
        while image.shape != wanted_shape:
            text, image = gen_captcha_text_and_image()

        gray = convert2gray(image)
        # Scale pixels by 1/255. (Alternative with zero mean:
        # (gray.flatten() - 128) / 128.)
        images[row, :] = gray.flatten() / 255
        labels[row, :] = text2vec(text)

    return images, labels
# Training
def train_crack_captcha_cnn():
    """Train the captcha CNN and save it once test-batch accuracy exceeds 85%.

    Builds the network via ``crack_captcha_cnn()`` and uses the module-level
    placeholders ``X``, ``Y`` and ``keep_prob``. Batches are generated on the
    fly with ``get_next_batch``. The loop runs until the accuracy measured on
    a fresh batch of 100 samples (checked every 100 steps) is above 0.85, at
    which point the model is written to ``./model/`` and the function returns.
    """
    import os  # local import: only needed to make sure the save dir exists

    output = crack_captcha_cnn()

    # A label is MAX_CAPTCHA one-hot groups of CHAR_SET_LEN entries, so the
    # softmax has to be taken per character rather than over the whole
    # concatenated vector. TF 1.x also requires labels/logits to be passed
    # by keyword.
    char_logits = tf.reshape(output, [-1, CHAR_SET_LEN])
    char_labels = tf.reshape(Y, [-1, CHAR_SET_LEN])
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=char_labels,
                                                logits=char_logits))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss)

    # Per-character accuracy: compare the arg-max class of every position.
    predict = tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN])
    max_idx_p = tf.argmax(predict, 2)
    max_idx_l = tf.argmax(tf.reshape(Y, [-1, MAX_CAPTCHA, CHAR_SET_LEN]), 2)
    correct_pred = tf.equal(max_idx_p, max_idx_l)
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    saver = tf.train.Saver()

    # Create the checkpoint directory up front so a missing folder does not
    # make the final save fail after training has already finished.
    model_dir = "./model"
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        step = 0
        while True:
            batch_x, batch_y = get_next_batch(64)
            _, loss_ = sess.run([optimizer, loss],
                                feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.75})
            print(step, loss_)

            # Evaluate accuracy every 100 steps (dropout disabled).
            if step % 100 == 0:
                batch_x_test, batch_y_test = get_next_batch(100)
                acc = sess.run(accuracy,
                               feed_dict={X: batch_x_test, Y: batch_y_test, keep_prob: 1.})
                print(step, acc)
                # Once accuracy is above 85%, save the model and stop training.
                if acc > 0.85:
                    saver.save(sess, "./model/crack_captcha.model", global_step=step)
                    break
            step += 1
def crack_captcha(captcha_image, checkpoint_path=None):
    """Predict the characters of one captcha image with a saved model.

    Builds the network via ``crack_captcha_cnn()`` and feeds the module-level
    placeholders ``X`` and ``keep_prob``.

    Args:
        captcha_image: a single flattened image, fed to ``X`` as a batch of one.
        checkpoint_path: checkpoint to restore. When omitted, the most recent
            checkpoint in ``./model`` is used, because training saves under a
            step number that is not known in advance. If none is found, the
            previous fixed path ``./model/crack_captcha.model-1500`` is tried.

    Returns:
        A list with the predicted class index for each character position.
    """
    output = crack_captcha_cnn()
    predict = tf.argmax(tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN]), 2)
    saver = tf.train.Saver()

    if checkpoint_path is None:
        # latest_checkpoint returns None when the directory has no checkpoint
        # state file; fall back to the historical hard-coded name in that case.
        checkpoint_path = (tf.train.latest_checkpoint("./model")
                           or "./model/crack_captcha.model-1500")

    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)
        # keep_prob of 1 disables dropout for inference.
        text_list = sess.run(predict, feed_dict={X: [captcha_image], keep_prob: 1})

    return text_list[0].tolist()
if __name__ == '__main__':
    # Mode switch: 0 trains a model, 1 predicts one captcha with a saved model.
    # train = 0
    train = 1

    # Character set. Only digits are used here; lowercase/uppercase letters
    # (and '_' as padding for captchas shorter than the maximum length) could
    # be appended to char_set to recognise a larger alphabet.
    number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    char_set = number
    CHAR_SET_LEN = len(char_set)

    # Image size
    IMAGE_HEIGHT = 60
    IMAGE_WIDTH = 160

    # Generate one sample; its text length fixes the captcha length used to
    # size the label placeholder.
    text, image = gen_captcha_text_and_image()
    MAX_CAPTCHA = len(text)

    # Placeholders are module-level because the training and prediction
    # functions reference X, Y and keep_prob as globals.
    X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH])
    Y = tf.placeholder(tf.float32, [None, MAX_CAPTCHA * CHAR_SET_LEN])
    keep_prob = tf.placeholder(tf.float32)  # dropout

    if train == 0:
        print("验证码图像channel:", image.shape)  # (60,160,3)
        print("验证码文本最长字符数", MAX_CAPTCHA)
        train_crack_captcha_cnn()
    elif train == 1:
        # Show the generated captcha with its true text before predicting.
        f = plt.figure()
        ax = f.add_subplot(111)
        ax.text(0.1, 0.9, text, ha='center', va='center', transform=ax.transAxes)
        plt.imshow(image)
        plt.show()

        # Same preprocessing as the training batches: grayscale, flatten, /255.
        image = convert2gray(image)
        image = image.flatten() / 255

        predict_text = crack_captcha(image)
        print("正确:{} 预测:{}".format(text, predict_text))
# 浙公网安备 33010602011771号