Golang调用TensorFlow模型

1. 安装Go版TensorFlow

TensorFlow 提供了一个 Go API,该 API 特别适合加载用 Python 创建的模型并在 Go 应用中运行这些模型。

安装TensorFlow C库

下载地址

TensorFlow C 库网址
Linux
Linux CPU only https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.14.0.tar.gz
Linux GPU support https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.14.0.tar.gz
macOS
macOS CPU only https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.14.0.tar.gz
Windows
Windows CPU only https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-1.14.0.zip
Windows GPU support https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-1.14.0.zip

解压 :

tar -C $dir -xzf tar_file

添加到动态库搜索路径:

export LIBRARY_PATH=$LIBRARY_PATH:$dir/lib

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$dir/lib

如果你已经解压到了/usr/local下,则不需要配置LIBRARY_PATH和LD_LIBRARY_PATH,只需要执行sudo ldconfig即可

安装 TensorFlow Go

go get github.com/tensorflow/tensorflow/tensorflow/go

2. 用Python训练TensorFlow模型并保存

注意:Python的TensorFlow版本不能高于Go的TensorFlow版本,否则Go加载模型文件时会报错。

import tensorflow as tf
from keras import backend as K

# Create an explicit TF session and register it as the Keras backend session.
# The same `sess` object is later handed to SavedModelBuilder when exporting.
sess = tf.Session()
K.set_session(sess)

def build_deep_cross():
    """Build and compile the Deep & Cross Keras model.

    One `Input` of shape (1,) is created per entry of `feature_names`, and
    each input layer is named after its feature. The outputs of
    `build_cross_net` and `build_dnn` are concatenated and fed to a single
    sigmoid unit named "output_layer".

    NOTE(review): `feature_names`, `f_dim_vectors`, `continuous_input`,
    `build_cross_net` and `build_dnn` are not defined in this snippet; they
    presumably come from the code the author elided.

    Returns:
        The compiled Keras `Model` (Adam optimizer, binary cross-entropy
        loss, accuracy metric).
    """
    inputs = []
    for i, feature_name in enumerate(feature_names):
        cate_in = Input((1,), name=feature_name)
        inputs.append(cate_in)
        # (a lot of code is omitted here)
    cross_net = build_cross_net(f_dim_vectors)  # CrossNet
    deep_out = build_dnn(f_dim_vectors, continuous_input)  # deep network
    # Combine the CrossNet branch with the deep network branch.
    concat_cross_deep = Concatenate()([cross_net, deep_out])
    outputs = Dense(1, activation="sigmoid", name="output_layer")(concat_cross_deep)
    # Take note of the model's inputs and outputs: the Go code refers to them by name.
    model = Model(inputs=inputs, outputs=outputs)
    solver = Adam(lr=0.01, decay=0.1)
    model.compile(optimizer=solver, loss='binary_crossentropy', metrics=['acc'])
    return model

# Build and train the model.
# NOTE(review): X_train / Y_train are not defined in this snippet.
model=build_deep_cross()
model.fit(X_train, Y_train, batch_size=256, epochs=10)
# Export a SavedModel into the "dcnModel" directory specifically for the Go side.
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder("dcnModel")
# The model must be given a tag, otherwise Go cannot load it.
builder.add_meta_graph_and_variables(sess, ["myTag"])
# Write the SavedModel to disk.
builder.save()

 

3. Golang加载TensorFlow模型

package main

import (
    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "strconv"
    "strings"
    "sync"
    "fmt"
)

// DCN bundles a loaded TensorFlow SavedModel with the names of its input
// features.
type DCN struct {
    // model is the SavedModel returned by tf.LoadSavedModel.
    model        *tf.SavedModel
    // featureNames holds the graph operation names of the input layers.
    // Entry i is the operation that column i of Predict's input is fed to.
    featureNames []string
}

var (
    // dcn is the package-level DCN instance handed out by GetDCNInstance.
    dcn     *DCN
    // dcnOnce ensures the initialization of dcn runs at most once.
    dcnOnce sync.Once
)

// GetDCNInstance returns the process-wide DCN, loading the SavedModel on the
// first call.
//
// modelFile is the SavedModel export directory and tags are the tags the model
// was saved with; both are only used by the first call, later calls ignore
// them and return the instance created earlier.
//
// It returns nil when the model could not be loaded. Because initialization is
// guarded by sync.Once, a failed load is not retried.
func GetDCNInstance(modelFile string, tags []string) *DCN {
    // No unsynchronized "dcn != nil" fast path: reading the global outside
    // of Once.Do would race with the write inside it. Once.Do is already cheap
    // after the first call and orders that write before the read below.
    dcnOnce.Do(func() {
        // The Go TensorFlow version used by LoadSavedModel must not be lower
        // than the TensorFlow version used with
        // tf.saved_model.builder.SavedModelBuilder when exporting.
        model, err := tf.LoadSavedModel(modelFile, tags, nil)
        if err != nil {
            fmt.Printf("read dcn model file %s failed: %v", modelFile, err)
            return
        }
        inst := &DCN{
            model:        model,
            featureNames: []string{"age", "work_year", "gender"},
        }
        // The first model.Session.Run call is slow, so warm the session up
        // once right after loading.
        inst.Predict([][]float32{{0.31, 0.09, 1.0}})
        // Publish only a fully initialized instance, so callers never get a
        // DCN whose model is nil.
        dcn = inst
    })
    return dcn
}

// Predict returns the predicted click-through rate for every row of X.
//
// X is a row-major batch: X[j][i] is feature i of sample j, and every row must
// have exactly one value per configured feature name. Continuous features must
// already be normalized and categorical features already converted to indices.
//
// The result has one score per row. Predict returns nil (after printing a
// message) when X is empty, the model is not loaded, a row has the wrong
// number of features, a named operation is missing from the graph, or
// TensorFlow reports an error.
func (d DCN) Predict(X [][]float32) []float32 {
    if len(X) == 0 {
        return nil
    }
    if d.model == nil {
        fmt.Printf("predict failed: %v", "model is not loaded")
        return nil
    }
    numFeatures := len(d.featureNames)
    // Validate every row, not just the first, so ragged input cannot cause
    // an index-out-of-range panic while the columns are extracted.
    for _, row := range X {
        if len(row) != numFeatures {
            fmt.Printf("feature number of x is %d, but should be %d", len(row), numFeatures)
            return nil
        }
    }

    // Each input layer takes one column of X, shaped (len(X), 1).
    feeds := make(map[tf.Output]*tf.Tensor, numFeatures)
    for i, name := range d.featureNames {
        column := make([][]float32, len(X))
        for j, row := range X {
            column[j] = []float32{row[i]}
        }
        tensor, err := tf.NewTensor(column)
        if err != nil {
            fmt.Printf("predict failed: %v", err)
            return nil
        }
        // Input layer defined on the Python TensorFlow/Keras side.
        // Graph.Operation returns nil for an unknown name.
        op := d.model.Graph.Operation(name)
        if op == nil {
            fmt.Printf("predict failed: input operation %q not found in graph", name)
            return nil
        }
        feeds[op.Output(0)] = tensor
    }

    // Output layer defined on the Python TensorFlow/Keras side.
    outOp := d.model.Graph.Operation("output_layer/Sigmoid")
    if outOp == nil {
        fmt.Printf("predict failed: output operation %q not found in graph", "output_layer/Sigmoid")
        return nil
    }

    // Per the original author's measurement, this call takes about 2 ms
    // whether the batch holds 1 row or 300 rows.
    result, err := d.model.Session.Run(feeds, []tf.Output{outOp.Output(0)}, nil)
    if err != nil {
        fmt.Printf("predict failed: %v", err)
        return nil
    }
    if len(result) == 0 {
        fmt.Printf("predict failed: %v", "session returned no output tensor")
        return nil
    }
    scores, ok := result[0].Value().([][]float32)
    if !ok {
        fmt.Printf("predict failed: unexpected output type %T", result[0].Value())
        return nil
    }

    // Flatten the (n, 1) output into n scores.
    out := make([]float32, len(scores))
    for i, s := range scores {
        if len(s) == 0 {
            fmt.Printf("predict failed: output row %d is empty", i)
            return nil
        }
        out[i] = s[0]
    }
    return out
}

func main(){
    dcn := rank.GetDCNInstance("dcnModel", []string{"myTag"})
    X := []float32{0.31,0.09,1.0}
    input := [][]float32{X}
    scores := dcn.Predict(input)
}

 

posted @ 2019-08-16 14:31  张朝阳  阅读(2680)  评论(0编辑  收藏