
Keras bug in model.predict

Posted on 2018-07-25 14:41 by MaHaLo

When I used Keras to run predictions behind a web service, model.predict raised an error. The error message looks like the following:

self._make_predict_function()
  File "/usr/local/lib/python3.4/dist-packages/keras/engine/training.py", line 679, in _make_predict_function
    **self._function_kwargs)
  File "/usr/local/lib/python3.4/dist-packages/keras/backend/tensorflow_backend.py", line 615, in function
    return Function(inputs, outputs, updates=updates)
  File "/usr/local/lib/python3.4/dist-packages/keras/backend/tensorflow_backend.py", line 589, in __init__
    with tf.control_dependencies(self.outputs):
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/ops.py", line 3192, in control_dependencies
    return get_default_graph().control_dependencies(control_inputs)
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/ops.py", line 2993, in control_dependencies
    c = self.as_graph_element(c)
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/ops.py", line 2291, in as_graph_element
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("Sigmoid_2:0", shape=(?, 17), dtype=float32) is not an element of this graph.

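This error appears when using the TensorFlow backend. After searching for it, I found that it is a known problem with using a TensorFlow graph across threads. In TensorFlow 1.x the default graph is a property of the current thread, so a request handler running on another thread may not see the graph the model was built in, and the model's output tensor is then reported as "not an element of this graph".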


Here's how to fix the bug:

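import tensorflow as tf  # TensorFlow 1.x API

# Right after loading or constructing your model, save the TensorFlow graph:
graph = tf.get_default_graph()

# In the other thread (or perhaps in an asynchronous event handler),
# re-enter that graph before running inference (run_inference is a hypothetical helper name):
def run_inference(model, x):
    # Only reading the module-level `graph`, so no `global` statement is needed.
    with graph.as_default():
        return model.predict(x)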


Reference: https://github.com/fchollet/keras/issues/2397