tf.strided_slice_and_tf.fill_and_tf.concat
tf.strided_slice,tf.fill,tf.concat使用实例
其中,我们需要对tensor data进行切片,tf.strided_slice使用方法请参考
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# process_decoder_input
data = tf.constant(
[
[4, 5, 20, 20, 22, 3], [17, 19, 28, 8, 7, 3], [5, 13, 15, 24, 26, 3], [5, 20, 25, 4, 5, 3],
[4, 12, 14, 15, 5, 3], [4, 7, 7, 16, 23, 3], [7, 8, 10, 13, 19, 3]
])
batch_size = 6
ending = tf.strided_slice(data, [0, 0], [6, -1], [1, 1])
fill = tf.fill([6, 1], 2)
decoder_input = tf.concat([tf.fill([batch_size, 1], 2), ending], 1)
# Decoder
# 先对target数据进行预处理
def process_decoder_input(data, vocab_to_int, batch_size):
"""
补充<GO>,并移除最后一个字符
"""
# cut掉最后一个字符
ending = tf.strided_slice(data, [0, 0], [batch_size, -1], [1, 1])
fill = tf.fill([batch_size, 1], vocab_to_int['<GO>'])
# vocab_to_int['<GO>']在本例中是2,经过在列维度上的合并,每个序列都是以GO(对应数值为2)开头
decoder_input = tf.concat([fill, ending], 1)
return ending, fill, decoder_input
data = tf.constant(
[
[4, 5, 20, 20, 22, 3],
[17, 19, 28, 8, 7, 3],
[5, 13, 15, 24, 26, 3],
[5, 20, 25, 4, 5, 3],
[4, 12, 14, 15, 5, 3],
[4, 7, 7, 16, 23, 3],
[7, 8, 10, 13, 19, 3]
]
)
target_letter_to_int = {
'<PAD>': 0, '<UNK>': 1, '<GO>': 2, '<EOS>': 3,
'a': 4, 'b': 5, 'c': 6, 'd': 7, 'e': 8, 'f': 9, 'g': 10, 'h': 11, 'i': 12, 'j': 13, 'k': 14, 'l': 15, 'm': 16,
'n': 17, 'o': 18, 'p': 19, 'q': 20, 'r': 21, 's': 22, 't': 23, 'u': 24, 'v': 25, 'w': 26, 'x': 27, 'y': 28, 'z': 29}
batch_size = 6
ending, fill, decoder_input = process_decoder_input(data, target_letter_to_int, batch_size)
with tf.Session() as sess: # 初始化会话
sess.run(tf.global_variables_initializer())
print('ending:\n', sess.run(ending))
print('fill:\n', sess.run(fill))
print('decoder_input:\n', sess.run(decoder_input))
结果如下:
''' ending: [[ 4 5 20 20 22] [17 19 28 8 7] [ 5 13 15 24 26] [ 5 20 25 4 5] [ 4 12 14 15 5] [ 4 7 7 16 23]] fill: [[2] [2] [2] [2] [2] [2]] decoder_input: [[ 2 4 5 20 20 22] [ 2 17 19 28 8 7] [ 2 5 13 15 24 26] [ 2 5 20 25 4 5] [ 2 4 12 14 15 5] [ 2 4 7 7 16 23]] '''

浙公网安备 33010602011771号