Implementing TensorFlow Model Prediction in Scala

Creating a Scala Project in IDEA

1. Install the Scala plugin.

2. Download the Scala SDK archive and unpack it: https://www.scala-lang.org/download/all.html

3. New Project -> Scala -> IDEA -> for the Scala SDK, select the unpacked root directory.

4. New Scala Class -> choose Object, then right-click the class and run it:

object HelloWorld {
  def main(args: Array[String]): Unit = {
    println("Hello World")
  }
}
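
The UDF code later in this post relies on Spark and the TensorFlow Java API. If the project is managed with sbt (the post itself uses a plain IDEA project, where you would add the same jars as module dependencies instead), the dependencies would look roughly like the sketch below; the artifact versions are illustrative assumptions, not taken from the original setup:

// build.sbt sketch: dependencies assumed by the UDF code below (versions are examples only)
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-sql"  % "2.4.5" % "provided",
  "org.apache.spark" %% "spark-hive" % "2.4.5" % "provided",
  "org.tensorflow"   %  "tensorflow" % "1.15.0"  // Java TensorFlow API (org.tensorflow package)
)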

 

Registering the Model Inference UDF
Problem 1: fixing the "A master URL must be set in your configuration" error:

https://blog.csdn.net/shenlanzifa/article/details/42679577
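
This error means Spark was given no master URL, which happens when the job is launched directly from the IDE instead of through spark-submit. A minimal sketch of the fix for local debugging (assumption: you only want to run locally; drop the setMaster call when submitting to a cluster, where the master comes from the command line):

// Only for running/debugging inside IDEA: point Spark at a local master
val sparkConf = new SparkConf()
  .setAppName("TfDataFrame")
  .setMaster("local[*]")  // or pass -Dspark.master=local[*] as a JVM option instead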

Problem 2: loading the TF model with tf.SavedModelBundle.load fails with "Could not find SavedModel .pb or .pbtxt"

TensorFlow has two PB model formats: FrozenPB and SavedModel PB; the latter carries signatures.

Solution: use Python to convert the FrozenPB into a SavedModel PB:

import tensorflow as tf  # TensorFlow 1.x API

pb_path = r"./inception-2015-12-05/classify_image_graph_def.pb"
with tf.gfile.GFile(pb_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
with tf.Session() as session:
    # First print the graph operations to find the model's input and output names
    print(tf.get_default_graph().get_operations())
    # Convert the FrozenPB into a SavedModel PB; the default "serve" tag/signature is added
    tf.saved_model.simple_save(
        session,
        "./savedmodel/",
        inputs={"image": tf.get_default_graph().get_tensor_by_name("DecodeJpeg/contents:0")},
        outputs={"scores": tf.get_default_graph().get_tensor_by_name("softmax:0")})
    # Load the SavedModel PB back to verify it; the tag defaults to "serve"
    tf.saved_model.loader.load(session, [tf.saved_model.tag_constants.SERVING], "./savedmodel/")

The exported SavedModel can then be loaded through the Java TensorFlow API (called here from Scala):

tf.SavedModelBundle.load("C:/Users/xxx/Desktop/model_debug/savedmodel", "serve")
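
For reference, a fuller sketch of loading that SavedModel and running a single prediction from Scala, assuming the input/output tensor names used in the export above ("DecodeJpeg/contents:0" and "softmax:0") and an illustrative image path:

// Sketch only: load the exported SavedModel and run one prediction
import org.{tensorflow => tf}
import java.nio.file.{Files, Paths}

val bundle = tf.SavedModelBundle.load("./savedmodel", "serve")
val imgBytes = Files.readAllBytes(Paths.get("./inception-2015-12-05/cropped_panda2.jpg"))
val input = tf.Tensor.create(imgBytes)  // DT_STRING scalar holding the raw JPEG bytes
val scores = bundle.session().runner()
  .feed("DecodeJpeg/contents:0", input)
  .fetch("softmax:0")
  .run().get(0)
println(scores)  // a [1, num_classes] float tensor
input.close(); scores.close(); bundle.close()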
Using the UDF in Spark

For details, see: https://www.cnblogs.com/cc11001100/p/9463909.html

A UDF parameter declared as Array[Byte] can be passed a Java byte[] array directly.

import org.{tensorflow => tf}
import java.nio.file.{Files, Paths}

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import java.util.Arrays

import scala.collection.JavaConversions.asScalaBuffer

object PredictUDF extends Serializable {
  def main(args: Array[String]): Unit = {

    // Build the prediction function and register it as a Spark SQL UDF
    // Both the image and the model are passed in as byte arrays
    val classificationPredict = (imgData: Array[Byte], modelData: Array[Byte]) => {
      // load model
      val graph = new tf.Graph()
      graph.importGraphDef(modelData)
      val session = new tf.Session(graph)

      // get input and output
      val opIter = graph.operations()
      var inputName: String = null
      if (opIter.hasNext) {
        inputName = opIter.next().name() + ":0"
      }
      var outputName: String = null
      while (opIter.hasNext) {
        outputName = opIter.next().name() + ":0"
      }

      // model predict
      var y_pred: Array[(Float, Int)] = null
      if (inputName != null && outputName != null) {
        val x = tf.Tensor.create(imgData)
        val y = session.runner().feed(inputName, x).fetch(outputName).run().get(0)
        val scores = Array.ofDim[Float](y.shape()(0).toInt, y.shape()(1).toInt)
        y.copyTo(scores)
        y_pred = scores.map(x => (x.max, x.indexOf(x.max)))
      }
      session.close()
      graph.close()  // release the native graph handle as well
      y_pred
    }

    // Test 1: call the prediction function directly
    // val modelPath: String = "C:/Users/xxx/Desktop/model_debug/inception-2015-12-05/classify_image_graph_def.pb"
    // val imgData = Files.readAllBytes(Paths.get("C:/Users/xxx/Desktop/model_debug/inception-2015-12-05/cropped_panda2.jpg"))
    // val modelData = Files.readAllBytes(Paths.get(modelPath))
    // var y_preds = classificationPredict(imgData, modelData)
    // println(y_preds)

    // Register the model prediction UDF
    val sparkConf = new SparkConf()
      .set("spark.serializer","org.apache.spark.serializer.KryoSerializer")
    val spark = SparkSession
      .builder().config(sparkConf)
      .appName("TfDataFrame")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._
    // test udf function
    spark.udf.register("classificationPredict", classificationPredict)
    val modelData = Files.readAllBytes(Paths.get("C:/Users/xxx/Desktop/model_debug/inception-2015-12-05/classify_image_graph_def.pb"))
    val imgData = Files.readAllBytes(Paths.get("C:/Users/xxx/Desktop/测试图片/COCO_val2014_000000000397.jpg"))
    Seq((imgData, modelData)).toDF("imgData", "modelData").createOrReplaceTempView("resources")
    spark.sql("select imgData, classificationPredict(imgData, modelData) as y_preds from resources").show()
  }
}

Note: prediction over multiple images is still to be investigated.
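
One possible direction (an untested sketch, not from the original post): keep the UDF as-is and feed it one DataFrame row per image, attaching the model bytes as a literal column so every row reuses them. The paths and column names below are illustrative:

// Untested sketch: one row per image, model bytes attached via lit()
import org.apache.spark.sql.functions.{col, lit, callUDF}

val imgPaths = Seq("img1.jpg", "img2.jpg")  // illustrative paths
val images = imgPaths.map(p => Files.readAllBytes(Paths.get(p))).toDF("imgData")

images
  .withColumn("y_preds", callUDF("classificationPredict", col("imgData"), lit(modelData)))
  .show()

Note that each row still re-imports the graph inside the UDF, so for real batches the loaded model would be better cached per executor (for example, in a lazy val inside a singleton object).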
