MXNet学习:预测结果-识别单张图片

本文用到了 mx.model 模块中 FeedForward 的 load 和 predict 方法：先用 FeedForward.load 加载已保存的模型，再对单张图片调用 predict 得到各类别的预测分数。

import os
import mxnet as mx
import numpy as np
import Image
from collections import namedtuple

# Lightweight batch record exposing a single ``data`` field. Presumably kept
# for mx.mod.Module-style forward() calls; it is not referenced by the
# FeedForward-based code shown here.
Batch = namedtuple('Batch', 'data')
# Label printed for each output index of the network: the digits 0-9.
synsets = list(range(10))


def predict(img_url, model, synsets, size=(28, 28), save_path=None):
    """Classify one image file with a trained MXNet FeedForward model.

    The image is converted to 8-bit grayscale, resized to ``size``, scaled
    to [0, 1] and fed to ``model.predict`` as a one-sample batch shaped
    (1, 1, height, width). One "label: score" line is printed per class.

    The input file itself is never modified; pass ``save_path`` to write the
    preprocessed (grayscale, resized) image somewhere for inspection.

    Args:
        img_url: Path of the image file to classify.
        model: Trained ``mx.model.FeedForward`` (e.g. from ``FeedForward.load``).
        synsets: Class labels, in the same order as the model's output vector.
        size: (width, height) the image is resized to; must match the
            network's expected input. Defaults to 28x28.
        save_path: Optional path to save the preprocessed image to.

    Returns:
        The model's output vector for this image (one score per class).
    """
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    # under its current name. Fall back for the legacy PIL module.
    resample = getattr(Image, "LANCZOS", None)
    if resample is None:
        resample = Image.ANTIALIAS

    source = Image.open(img_url)
    try:
        gray = source.convert('L')
    finally:
        # Pillow images expose close(); the legacy PIL module does not.
        if hasattr(source, "close"):
            source.close()
    resized = gray.resize(tuple(size), resample)

    # The original code saved over img_url here, destroying the user's
    # image; only save when a separate destination is requested.
    if save_path is not None:
        resized.save(save_path)

    pixels = np.asarray(resized, dtype=np.uint8)
    # (height, width) -> (batch=1, channel=1, height, width), scaled to [0, 1].
    # NOTE(review): MNIST-style digits are light-on-dark; photos of dark ink
    # on white paper may need inverting (1 - x) — confirm against training data.
    batch = pixels.reshape((1, 1) + pixels.shape).astype(np.float32) / 255
    val = mx.io.NDArrayIter(data=batch)
    res = model.predict(X=val)[0]

    # Pair labels with scores instead of assuming exactly 10 classes; %s
    # also works for non-integer labels.
    for label, score in zip(synsets, res):
        print("%s: %.2f" % (label, score))
    return res


# Load the checkpoint written by model.save('MNIST_MXNet', 100).
model = mx.model.FeedForward.load('MNIST_MXNet', 100)

# raw_input exists only on Python 2; Python 3 renamed it to input().
try:
    read_line = raw_input
except NameError:
    read_line = input

# Interactive loop: classify one image path per line until EOF (Ctrl-D /
# Ctrl-Z) or Ctrl-C. Unreadable or non-image paths are reported and the loop
# continues instead of crashing.
while True:
    try:
        img_url = read_line("Enter the img_url: ").strip()
    except (EOFError, KeyboardInterrupt):
        print("")
        break
    if not img_url:
        continue
    try:
        predict(img_url, model, synsets)
    except (IOError, OSError) as err:
        print("Cannot read %r: %s" % (img_url, err))

训练后保存模型时调用的是 model.save('MNIST_MXNet', 100)，因此加载时必须使用相同的前缀 'MNIST_MXNet' 和 epoch 编号 100。

posted @ 2016-12-26 01:07  Mu001999  阅读(3307)  评论(0)    收藏  举报