Tensorflow-hub [Example Walkthrough 2]


Tensorflow-hub [Example Walkthrough 1]

3 An example based on text word embeddings

3.1 Creating the Module

As Tensorflow-hub [Example Walkthrough 1] showed, TF-Hub removes much of the work that used to be required.
First, suppose we have a text file of word embeddings:

token1 1.0 2.0 3.0 4.0 5.0
token2 2.0 3.0 4.0 5.0 6.0

This example reads that file and generates a TF-Hub Module from it, using the following command:

python export.py --embedding_file=/tmp/embedding.txt --export_path=/tmp/module
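
If you want to try the command as-is, the input file is easy to create first; a minimal sketch that writes the two sample rows above to /tmp/embedding.txt so the path matches the command:

# Write the two sample embedding rows shown above to /tmp/embedding.txt.
with open("/tmp/embedding.txt", "w") as f:
  f.write("token1 1.0 2.0 3.0 4.0 5.0\n")
  f.write("token2 2.0 3.0 4.0 5.0 6.0\n")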

Below is the source of export.py. By following the numbered comments in the code, you can trace how the Module is created.

# Import the required modules, as usual.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import shutil
import sys
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

FLAGS = None

EMBEDDINGS_VAR_NAME = "embeddings"


def parse_line(line):
  """该函数是为了解析./tmp/embedding.txt文件的每一行

  Args:
    line: (str) One line of the text embedding file.

  Returns:
    A token string and its embedding vector in floats.
  """
  columns = line.split()
  token = columns.pop(0)
  values = [float(column) for column in columns]
  return token, values


def load(file_path, parse_line_fn):
  """该函数是为了将/tmp/embedding.txt解析为numpy对象,并保存在内存中.

  Args:
    file_path: Path to the text embedding file.
    parse_line_fn: callback function to parse each file line.

  Returns:
    A tuple of (list of vocabulary tokens, numpy matrix of embedding vectors).

  Raises:
    ValueError: if the data in the file is inconsistent.
  """
  vocabulary = []
  embeddings = []
  embeddings_dim = None
  for line in tf.gfile.GFile(file_path):
    token, embedding = parse_line_fn(line)
    if not embeddings_dim:
      embeddings_dim = len(embedding)
    elif embeddings_dim != len(embedding):
      raise ValueError(
          "Inconsistent embedding dimension detected, %d != %d for token %s" %
          (embeddings_dim, len(embedding), token))

    vocabulary.append(token)
    embeddings.append(embedding)

  return vocabulary, np.array(embeddings)

''' This function shows how to define a Module, i.e. how to build its ModuleSpec '''
def make_module_spec(vocabulary_file, vocab_size, embeddings_dim,
                     num_oov_buckets, preprocess_text):
  """Makes a module spec to simply perform token to embedding lookups.

  Input of this module is a 1-D list of string tokens. For T tokens input and
  an M dimensional embedding table, the lookup result is a [T, M] shaped Tensor.

  Args:
    vocabulary_file: Text file where each line is a key in the vocabulary.
    vocab_size: The number of tokens contained in the vocabulary.
    embeddings_dim: The embedding dimension.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    preprocess_text: Whether to preprocess the input tensor by removing
      punctuation and splitting on spaces.

  Returns:
    A module spec object used for constructing a TF-Hub module.
  """

  ''' 1 - First define the function module_fn:
            use tf.placeholder as the input placeholder and build the whole graph;
            call hub.add_signature() to perform the registration-like step '''
  def module_fn():
    """Spec function for a token embedding module."""
    tokens = tf.placeholder(shape=[None], dtype=tf.string, name="tokens")

    embeddings_var = tf.get_variable(
        initializer=tf.zeros([vocab_size + num_oov_buckets, embeddings_dim]),
        name=EMBEDDINGS_VAR_NAME,
        dtype=tf.float32)

    lookup_table = tf.contrib.lookup.index_table_from_file(
        vocabulary_file=vocabulary_file,
        num_oov_buckets=num_oov_buckets,
    )
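    # Tokens not found in vocabulary_file are hashed into one of the
    # num_oov_buckets extra ids in [vocab_size, vocab_size + num_oov_buckets).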
    ids = lookup_table.lookup(tokens)
    combined_embedding = tf.nn.embedding_lookup(params=embeddings_var, ids=ids)
    hub.add_signature("default", {"tokens": tokens},
                      {"default": combined_embedding})

  ''' 1 - This function is the alternative to module_fn above (only one of the
             two is used, selected by preprocess_text):
             again tf.placeholder provides the input placeholder and the whole
             graph is built; hub.add_signature() performs the registration '''
  def module_fn_with_preprocessing():
    """Spec function for a full-text embedding module with preprocessing."""
    sentences = tf.placeholder(shape=[None], dtype=tf.string, name="sentences")
    # Perform a minimalistic text preprocessing by removing punctuation and
    # splitting on spaces.
    normalized_sentences = tf.regex_replace(
        input=sentences, pattern=r"\pP", rewrite="")
    tokens = tf.string_split(normalized_sentences, " ")

    # In case some of the input sentences are empty before or after
    # normalization, we will end up with empty rows. We do however want to
    # return embedding for every row, so we have to fill in the empty rows with
    # a default.
    tokens, _ = tf.sparse_fill_empty_rows(tokens, "")
    # In case all of the input sentences are empty before or after
    # normalization, we will end up with a SparseTensor with shape [?, 0]. After
    # filling in the empty rows we must ensure the shape is set properly to
    # [?, 1].
    tokens = tf.sparse_reset_shape(tokens)

    embeddings_var = tf.get_variable(
        initializer=tf.zeros([vocab_size + num_oov_buckets, embeddings_dim]),
        name=EMBEDDINGS_VAR_NAME,
        dtype=tf.float32)
    lookup_table = tf.contrib.lookup.index_table_from_file(
        vocabulary_file=vocabulary_file,
        num_oov_buckets=num_oov_buckets,
    )
    sparse_ids = tf.SparseTensor(
        indices=tokens.indices,
        values=lookup_table.lookup(tokens.values),
        dense_shape=tokens.dense_shape)

    combined_embedding = tf.nn.embedding_lookup_sparse(
        params=embeddings_var,
        sp_ids=sparse_ids,
        sp_weights=None,
        combiner="sqrtn")

    hub.add_signature("default", {"sentences": sentences},
                      {"default": combined_embedding})

  ''' 2 - Call hub.create_module_spec() to create the ModuleSpec object '''
  if preprocess_text:
    return hub.create_module_spec(module_fn_with_preprocessing)
  else:
    return hub.create_module_spec(module_fn)


def export(export_path, vocabulary, embeddings, num_oov_buckets,
           preprocess_text):
  """Exports a TF-Hub module that performs embedding lookups.

  Args:
    export_path: Location to export the module.
    vocabulary: List of the N tokens in the vocabulary.
    embeddings: Numpy array of shape [N+K,M] the first N rows are the
      M dimensional embeddings for the respective tokens and the next K
      rows are for the K out-of-vocabulary buckets.
    num_oov_buckets: How many out-of-vocabulary buckets to add.
    preprocess_text: Whether to preprocess the input tensor by removing
      punctuation and splitting on spaces.
  """
  # Write temporary vocab file for module construction.
  tmpdir = tempfile.mkdtemp()
  vocabulary_file = os.path.join(tmpdir, "tokens.txt")
  with tf.gfile.GFile(vocabulary_file, "w") as f:
    f.write("\n".join(vocabulary))
  vocab_size = len(vocabulary)
  embeddings_dim = embeddings.shape[1]
  spec = make_module_spec(vocabulary_file, vocab_size, embeddings_dim,
                          num_oov_buckets, preprocess_text)

  try:
    ''' 3 - Create a tf.Graph() and instantiate hub.Module(spec), which can then be applied like y = f(x) '''
    with tf.Graph().as_default():
      m = hub.Module(spec)
      # The embeddings may be very large (e.g., larger than the 2GB serialized
      # Tensor limit).  To avoid having them frozen as constant Tensors in the
      # graph we instead assign them through the placeholders and feed_dict
      # mechanism.
      p_embeddings = tf.placeholder(tf.float32)
      load_embeddings = tf.assign(m.variable_map[EMBEDDINGS_VAR_NAME],
                                  p_embeddings)
      ''' 4 - Create a Session(), run initialization (and training, iteration, etc. as usual); finally call m.export(export_path, sess) to export the Module '''
      with tf.Session() as sess:
        sess.run([load_embeddings], feed_dict={p_embeddings: embeddings})
        m.export(export_path, sess)
  finally:
    shutil.rmtree(tmpdir)


def maybe_append_oov_vectors(embeddings, num_oov_buckets):
  """Adds zero vectors for oov buckets if num_oov_buckets > 0.

  Since we are assigning zero vectors, adding more than one oov bucket is only
  meaningful if we perform fine-tuning.

  Args:
    embeddings: Embeddings to extend.
    num_oov_buckets: Number of OOV buckets in the extended embedding.
  """
  num_embeddings = np.shape(embeddings)[0]
  embedding_dim = np.shape(embeddings)[1]
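  # ndarray.resize zero-fills the newly added rows, so every OOV bucket
  # starts out as a zero vector.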
  embeddings.resize(
      [num_embeddings + num_oov_buckets, embedding_dim], refcheck=False)


def export_module_from_file(embedding_file, export_path, parse_line_fn,
                            num_oov_buckets, preprocess_text):
  # Load pretrained embeddings into memory.
  vocabulary, embeddings = load(embedding_file, parse_line_fn)

  # Add OOV buckets if num_oov_buckets > 0.
  maybe_append_oov_vectors(embeddings, num_oov_buckets)

  # Export the embedding vectors into a TF-Hub module.
  export(export_path, vocabulary, embeddings, num_oov_buckets, preprocess_text)


def main(_):
  export_module_from_file(FLAGS.embedding_file, FLAGS.export_path, parse_line,
                          FLAGS.num_oov_buckets, FLAGS.preprocess_text)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--embedding_file",
      type=str,
      default=None,
      help="Path to file with embeddings.")
  parser.add_argument(
      "--export_path",
      type=str,
      default=None,
      help="Where to export the module.")
  parser.add_argument(
      "--preprocess_text",
      action="store_true",
      help="Whether to preprocess the input tensor by removing punctuation and "
      "splitting on spaces. Use this if input is a dense tensor of untokenized "
      "sentences.")
  parser.add_argument(
      "--num_oov_buckets",
      type=int,
      default="1",
      help="How many OOV buckets to add.")
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

As the example above shows, the creation process is quite similar to the one in Tensorflow-hub [Example Walkthrough 1].
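
Stripped of the file handling, the creation flow reduces to the four numbered steps; here is a minimal sketch (the module body, 2 * x, and the export path are illustrative, not part of export.py):

import tensorflow as tf
import tensorflow_hub as hub

def module_fn():
  # 1 - placeholder in, build the graph, register the signature.
  x = tf.placeholder(dtype=tf.float32, shape=[None], name="x")
  y = 2 * x
  hub.add_signature("default", {"x": x}, {"default": y})

# 2 - build the ModuleSpec from the spec function.
spec = hub.create_module_spec(module_fn)

with tf.Graph().as_default():
  m = hub.Module(spec)  # 3 - instantiate; applicable like y = f(x).
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    m.export("/tmp/minimal_module", sess)  # 4 - export to disk.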

3.2 Using the Module

Below is the code that uses the Module we just created. It is organized as a handful of tests; following the numbered comments shows that consuming a Module is equally straightforward.

# Import the required modules.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

import export

_MOCK_EMBEDDING = "\n".join(
    ["cat 1.11 2.56 3.45", "dog 1 2 3", "mouse 0.5 0.1 0.6"])


class ExportTokenEmbeddingTest(tf.test.TestCase):

  def setUp(self):
    self._embedding_file_path = os.path.join(self.get_temp_dir(),
                                             "mock_embedding_file.txt")
    with tf.gfile.GFile(self._embedding_file_path, mode="w") as f:
      f.write(_MOCK_EMBEDDING)

  def testEmbeddingLoaded(self):
    vocabulary, embeddings = export.load(self._embedding_file_path,
                                         export.parse_line)
    self.assertEqual((3,), np.shape(vocabulary))
    self.assertEqual((3, 3), np.shape(embeddings))

  def testExportTokenEmbeddingModule(self):
    ''' 1 - First call the module-creation code above to generate a Module '''
    export.export_module_from_file(
        embedding_file=self._embedding_file_path,
        export_path=self.get_temp_dir(),
        parse_line_fn=export.parse_line,
        num_oov_buckets=1,
        preprocess_text=False)
    ''' 2 - Create a tf.Graph():
             call hub.Module to load the Module;
             create a tf.Session(), initialize, and evaluate y = f(x) to get
             the result '''
    with tf.Graph().as_default():
      hub_module = hub.Module(self.get_temp_dir())
      tokens = tf.constant(["cat", "lizard", "dog"])
      embeddings = hub_module(tokens)
      with tf.Session() as session:
        session.run(tf.tables_initializer())
        session.run(tf.global_variables_initializer())
        self.assertAllClose(
            session.run(embeddings),
            [[1.11, 2.56, 3.45], [0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])

  def testExportFulltextEmbeddingModule(self):
    ''' 1 - First call the module-creation code above to generate a Module '''
    export.export_module_from_file(
        embedding_file=self._embedding_file_path,
        export_path=self.get_temp_dir(),
        parse_line_fn=export.parse_line,
        num_oov_buckets=1,
        preprocess_text=True)
    ''' 2 - Create a tf.Graph():
             call hub.Module to load the Module;
             create a tf.Session(), initialize, and evaluate y = f(x) to get
             the result '''
    with tf.Graph().as_default():
      hub_module = hub.Module(self.get_temp_dir())
      tokens = tf.constant(["cat", "cat cat", "lizard. dog", "cat? dog", ""])
      embeddings = hub_module(tokens)
      with tf.Session() as session:
        session.run(tf.tables_initializer())
        session.run(tf.global_variables_initializer())
        self.assertAllClose(
            session.run(embeddings),
            [[1.11, 2.56, 3.45], [1.57, 3.62, 4.88], [0.70, 1.41, 2.12],
             [1.49, 3.22, 4.56], [0.0, 0.0, 0.0]],
            rtol=0.02)

if __name__ == "__main__":
  tf.test.main()
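
For completeness, consuming the exported Module outside a test harness is just as short; a minimal sketch, assuming the export command from section 3.1 was run with --export_path=/tmp/module:

import tensorflow as tf
import tensorflow_hub as hub

with tf.Graph().as_default():
  embed = hub.Module("/tmp/module")  # load the exported module from disk
  embeddings = embed(tf.constant(["token1", "token2", "unseen"]))
  with tf.Session() as sess:
    # Both initializers matter here: variables for the embedding matrix,
    # tables for the token -> id lookup table.
    sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
    print(sess.run(embeddings))  # shape [3, 5]; the "unseen" row is zeros.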