TensorFlow Model Prediction in Scala
Creating a Scala Project in IDEA
1. Download the Scala SDK archive and unpack it: https://www.scala-lang.org/download/all.html
2. New Project -> Scala -> IDEA -> point "Scala SDK" at the unpacked root directory
3. New Scala Class -> choose Object, then right-click the file and run it:
object HelloWorld {
  def main(args: Array[String]): Unit = {
    println("Hello World")
  }
}
Registering a Model Inference UDF
Problem 1: fixing the "A master URL must be set in your configuration" error. Spark throws this when no master URL has been configured, which typically happens when running the program directly from IDEA; see:
https://blog.csdn.net/shenlanzifa/article/details/42679577
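In short, the fix is to give Spark an explicit master. A minimal sketch of the two usual options for a local IDEA run (the app name here is arbitrary):

```scala
import org.apache.spark.SparkConf

object MasterUrlFix {
  def main(args: Array[String]): Unit = {
    // Option 1: set the master in code; "local[*]" uses all local cores.
    val conf = new SparkConf()
      .setAppName("TfDataFrame")
      .setMaster("local[*]")
    // Option 2: leave the code unchanged and add a JVM option to the
    // IDEA run configuration instead: -Dspark.master=local[*]
    println(conf.get("spark.master"))
  }
}
```

Setting the master in code is convenient for debugging but should be removed (or made conditional) before submitting with spark-submit, which supplies the master itself.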
Problem 2: tf.SavedModelBundle.load fails with "Could not find SavedModel .pb or .pbtxt".
TensorFlow has two PB model formats: the frozen PB (FrozenPB) and the SavedModel PB; the latter carries signatures, and SavedModelBundle.load only accepts the SavedModel format.
Fix: use Python to convert the frozen PB into a SavedModel PB.
import tensorflow as tf  # TF 1.x APIs

pb_path = r"./inception-2015-12-05/classify_image_graph_def.pb"
with tf.gfile.GFile(pb_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')

with tf.Session() as session:
    # First inspect the graph to find the input and output names
    print(tf.get_default_graph().get_operations())
    # Convert the frozen PB into a SavedModel PB; the default tag is "serve"
    tf.saved_model.simple_save(
        session, "./savedmodel/",
        inputs={"image": tf.get_default_graph().get_tensor_by_name("DecodeJpeg/contents:0")},
        outputs={"scores": tf.get_default_graph().get_tensor_by_name("softmax:0")})
    # Load the SavedModel PB back; the tag defaults to "serve"
    tf.saved_model.loader.load(session, [tf.saved_model.tag_constants.SERVING], "./savedmodel/")

# Java TensorFlow API equivalent:
# SavedModelBundle.load("C:/Users/xxx/Desktop/model_debug/savedmodel", "serve")
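Once converted, the SavedModel can be loaded from Scala through the TensorFlow Java API. A hedged sketch, assuming TF Java 1.x, the tag and tensor names exported above, and the local paths used elsewhere in these notes:

```scala
import org.tensorflow.{SavedModelBundle, Tensor}
import java.nio.file.{Files, Paths}

object LoadSavedModel {
  def main(args: Array[String]): Unit = {
    // Load the SavedModel exported above; the tag must match ("serve").
    val bundle = SavedModelBundle.load("C:/Users/xxx/Desktop/model_debug/savedmodel", "serve")
    val imgBytes = Files.readAllBytes(
      Paths.get("C:/Users/xxx/Desktop/model_debug/inception-2015-12-05/cropped_panda2.jpg"))
    // A byte[] tensor feeds the string input "DecodeJpeg/contents:0".
    val input = Tensor.create(imgBytes)
    val output = bundle.session().runner()
      .feed("DecodeJpeg/contents:0", input)
      .fetch("softmax:0")
      .run().get(0)
    // Copy the [1, numClasses] scores out and print the top score.
    val scores = Array.ofDim[Float](output.shape()(0).toInt, output.shape()(1).toInt)
    output.copyTo(scores)
    println(scores(0).max)
    input.close(); output.close(); bundle.close()
  }
}
```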
Using the UDF in Spark
For details, see: https://www.cnblogs.com/cc11001100/p/9463909.html
A UDF parameter of type Array[Byte] can directly accept a Java byte[] array.
import org.{tensorflow => tf}
import java.nio.file.{Files, Paths}
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object PredictUDF extends Serializable {
  def main(args: Array[String]): Unit = {
    // Build the prediction function, then register it as a Spark SQL UDF.
    // Both the image and the model are passed as byte arrays.
    val classificationPredict = (imgData: Array[Byte], modelData: Array[Byte]) => {
      // load the model
      val graph = new tf.Graph()
      graph.importGraphDef(modelData)
      val session = new tf.Session(graph)
      // take the first operation as input and the last one as output
      val opIter = graph.operations()
      var inputName: String = null
      if (opIter.hasNext) {
        inputName = opIter.next().name() + ":0"
      }
      var outputName: String = null
      while (opIter.hasNext) {
        outputName = opIter.next().name() + ":0"
      }
      // model predict
      var y_pred: Array[(Float, Int)] = null
      if (inputName != null && outputName != null) {
        val x = tf.Tensor.create(imgData)
        val y = session.runner().feed(inputName, x).fetch(outputName).run().get(0)
        val scores = Array.ofDim[Float](y.shape()(0).toInt, y.shape()(1).toInt)
        y.copyTo(scores)
        // (max score, its class index) for each row
        y_pred = scores.map(row => (row.max, row.indexOf(row.max)))
      }
      session.close()
      y_pred
    }

    // test 1: call the function directly
    // val modelPath: String = "C:/Users/xxx/Desktop/model_debug/inception-2015-12-05/classify_image_graph_def.pb"
    // val imgData = Files.readAllBytes(Paths.get("C:/Users/xxx/Desktop/model_debug/inception-2015-12-05/cropped_panda2.jpg"))
    // val modelData = Files.readAllBytes(Paths.get(modelPath))
    // val y_preds = classificationPredict(imgData, modelData)
    // println(y_preds)

    // register the model-prediction UDF
    // (when running locally from IDEA, also set a master, e.g. .setMaster("local[*]"); see problem 1)
    val sparkConf = new SparkConf()
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val spark = SparkSession
      .builder().config(sparkConf)
      .appName("TfDataFrame")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._

    // test 2: call the UDF from Spark SQL
    spark.udf.register("classificationPredict", classificationPredict)
    val modelData = Files.readAllBytes(Paths.get("C:/Users/xxx/Desktop/model_debug/inception-2015-12-05/classify_image_graph_def.pb"))
    val imgData = Files.readAllBytes(Paths.get("C:/Users/xxx/Desktop/测试图片/COCO_val2014_000000000397.jpg"))
    Seq((imgData, modelData)).toDF("imgData", "modelData").createOrReplaceTempView("resources")
    spark.sql("select imgData, classificationPredict(imgData, modelData) as y_preds from resources").show()
  }
}
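The per-row post-processing inside the UDF (taking each row's top score and its index) can be checked in isolation, without Spark or TensorFlow. A minimal sketch with made-up score rows:

```scala
object TopScoreDemo {
  def main(args: Array[String]): Unit = {
    // Two fake score rows, as if copied out of the output tensor.
    val scores: Array[Array[Float]] = Array(
      Array(0.1f, 0.7f, 0.2f),
      Array(0.5f, 0.3f, 0.2f)
    )
    // Same expression as in the UDF: (max score, argmax) per row.
    val yPred = scores.map(row => (row.max, row.indexOf(row.max)))
    println(yPred.mkString(", "))  // (0.7,1), (0.5,0)
  }
}
```

Note that `indexOf(row.max)` returns the first index at which the maximum occurs, which is the usual argmax convention for classification scores.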
Note: batched prediction over multiple images at once remains to be investigated.