Given a trained Transformer model, extract the attention matrices and visualize the attention.
First install tensorflow 1.13.1 and tensor2tensor 1.13.1 (for example: pip install tensorflow==1.13.1 tensor2tensor==1.13.1).
Run the visualization in a Jupyter notebook. The code below is adapted from tensor2tensor/tensor2tensor/visualization/visualization.py.
# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared code for visualizing transformer attentions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
# To register the hparams set
from tensor2tensor import models # pylint: disable=unused-import
from tensor2tensor import problems
from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_lib
import tensorflow.compat.v1 as tf
from tensor2tensor.utils import usr_dir
EOS_ID = 1
class AttentionVisualizer2(object):
"""Helper object for creating Attention visualizations."""
def __init__(
self, hparams_set,hparams,t2t_usr_dir, model_name, data_dir, problem_name, beam_size=1):
inputs, targets, samples, att_mats = build_model(
hparams_set,hparams, t2t_usr_dir, model_name, data_dir, problem_name, beam_size=beam_size)
# Fetch the problem
ende_problem = problems.problem(problem_name)
encoders = ende_problem.feature_encoders(data_dir)
self.inputs = inputs
self.targets = targets
self.att_mats = att_mats
self.samples = samples
self.encoders = encoders
  def encode(self, input_str):
    """Input str to features dict, ready for inference."""
    inputs = self.encoders["inputs"].encode(input_str) + [EOS_ID]
    batch_inputs = np.reshape(inputs, [1, -1, 1, 1])  # Make it 4D: (1, length, 1, 1).
    return batch_inputs

  def decode(self, integers):
    """List of ints to str."""
    integers = list(np.squeeze(integers))
    return self.encoders["targets"].decode(integers)

  def encode_list(self, integers):
    """List of ints to list of str."""
    integers = list(np.squeeze(integers))
    return self.encoders["inputs"].decode_list(integers)

  def decode_list(self, integers):
    """List of ints to list of str."""
    integers = list(np.squeeze(integers))
    return self.encoders["targets"].decode_list(integers)
  def get_vis_data_from_string(self, sess, input_string):
    """Constructs the data needed for visualizing attentions.

    Args:
      sess: A tf.Session object.
      input_string: The input sentence to be translated and visualized.

    Returns:
      Tuple of (
          output_string: The translated sentence.
          input_list: Tokenized input sentence.
          output_list: Tokenized translation.
          att_mats: Tuple of attention matrices; (
              enc_atts: Encoder self attention weights.
                A list of `num_layers` numpy arrays of size
                (batch_size, num_heads, inp_len, inp_len)
              dec_atts: Decoder self attention weights.
                A list of `num_layers` numpy arrays of size
                (batch_size, num_heads, out_len, out_len)
              encdec_atts: Encoder-Decoder attention weights.
                A list of `num_layers` numpy arrays of size
                (batch_size, num_heads, out_len, inp_len)
          )
    """
    encoded_inputs = self.encode(input_string)

    # Run inference graph to get the translation.
    out = sess.run(self.samples, {
        self.inputs: encoded_inputs,
    })

    # Run the decoded translation through the training graph to get the
    # attention tensors.
    att_mats = sess.run(self.att_mats, {
        self.inputs: encoded_inputs,
        self.targets: np.reshape(out, [1, -1, 1, 1]),
    })

    output_string = self.decode(out)
    input_list = self.encode_list(encoded_inputs)
    output_list = self.decode_list(out)

    return output_string, input_list, output_list, att_mats


def build_model(hparams_set, hparams, t2t_usr_dir, model_name, data_dir,
                problem_name, beam_size=1):
  """Build the graph required to fetch the attention weights.

  Args:
    hparams_set: HParams set to build the model with.
    hparams: String of comma-separated hparams overrides.
    t2t_usr_dir: Path to the user directory containing the custom model.
    model_name: Name of model.
    data_dir: Path to directory containing training data.
    problem_name: Name of problem.
    beam_size: (Optional) Number of beams to use when decoding a translation.
        If set to 1 (default) then greedy decoding is used.

  Returns:
    Tuple of (
        inputs: Input placeholder to feed in ids to be translated.
        targets: Targets placeholder to feed to translation when fetching
            attention weights.
        samples: Tensor representing the ids of the translation.
        att_mats: Tensors representing the attention weights.
    )
  """
  print(model_name)
  usr_dir.import_usr_dir(t2t_usr_dir)
  hparams = trainer_lib.create_hparams(
      hparams_set, hparams, data_dir=data_dir, problem_name=problem_name)
  # print(hparams)
  translate_model = registry.model(model_name)(
      hparams, tf.estimator.ModeKeys.EVAL)

  inputs = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="inputs")
  targets = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="targets")
  translate_model({
      "inputs": inputs,
      "targets": targets,
  })

  # Must be called after building the training graph, so that the dict will
  # have been filled with the attention tensors. BUT before creating the
  # inference graph otherwise the dict will be filled with tensors from
  # inside a tf.while_loop from decoding and are marked unfetchable.
  atts = get_att_mats(translate_model, model_name)

  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    samples = translate_model.infer({
        "inputs": inputs,
    }, beam_size=beam_size)["outputs"]

  return inputs, targets, samples, atts


def get_att_mats(translate_model, model_name):
  """Gets the tensors representing the attentions from a built model.

  The attentions are stored in a dict on the Transformer object while building
  the graph.

  Args:
    translate_model: Transformer object to fetch the attention weights from.
    model_name: Name of the model (used as the prefix of the attention keys).

  Returns:
    Tuple of attention matrices; (
        enc_atts: Encoder self attention weights.
          A list of `num_layers` numpy arrays of size
          (batch_size, num_heads, inp_len, inp_len)
        dec_atts: Decoder self attention weights.
          A list of `num_layers` numpy arrays of size
          (batch_size, num_heads, out_len, out_len)
        encdec_atts: Encoder-Decoder attention weights.
          A list of `num_layers` numpy arrays of size
          (batch_size, num_heads, out_len, inp_len)
    )
  """
  enc_atts = []
  dec_atts = []
  encdec_atts = []

  prefix = "%s/body/" % model_name
  postfix_self_attention = "/multihead_attention/dot_product_attention"
  if translate_model.hparams.self_attention_type == "dot_product_relative":
    postfix_self_attention = ("/multihead_attention/"
                              "dot_product_attention_relative")
  postfix_encdec = "/multihead_attention/dot_product_attention"
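  # For example, with model_name="transformer" the encoder self-attention key
  # for layer 0 resolves to
  # "transformer/body/encoder/layer_0/self_attention/multihead_attention/dot_product_attention".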
  for i in range(translate_model.hparams.num_hidden_layers):
    enc_att = translate_model.attention_weights[
        "%sencoder/layer_%i/self_attention%s"
        % (prefix, i, postfix_self_attention)]
    dec_att = translate_model.attention_weights[
        "%sdecoder/layer_%i/self_attention%s"
        % (prefix, i, postfix_self_attention)]
    encdec_att = translate_model.attention_weights[
        "%sdecoder/layer_%i/encdec_attention%s" % (prefix, i, postfix_encdec)]

    enc_atts.append(enc_att)
    dec_atts.append(dec_att)
    encdec_atts.append(encdec_att)

  return enc_atts, dec_atts, encdec_atts

from IPython.display import display


def call_html():
  import IPython
  display(IPython.core.display.HTML('''
      <script src="/static/components/requirejs/require.js"></script>
      <script>
        requirejs.config({
          paths: {
            base: '/static/base',
            "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
            jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
          },
        });
      </script>
      '''))

import os
from tensor2tensor import problems
from tensor2tensor.bin import t2t_decoder  # To register the hparams set
# from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_lib
from tensor2tensor.visualization import attention
# from src.visualization import visualization

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# HParams
problem_name = 'translate_ende_wmt32k'  # problem (dataset)
data_dir = os.path.expanduser('/home/usrname/collaboration/t2t_data/%s' % problem_name)  # data directory
model_name = "collaboration"  # model name
hparams_set = "collaboration_base"  # hparams set
hparams = 'max_length=128,num_hidden_layers=6,usedegray=1.0,reuse_n=0'  # custom hparams overrides (adjust to your own setup)
t2t_usr_dir = './src/'  # path to the user directory containing the custom model

visualizer = AttentionVisualizer2(hparams_set, hparams, t2t_usr_dir, model_name,
                                  data_dir, problem_name, beam_size=1)
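# Checkpoints written by tensor2tensor store a global_step; create a matching
# variable in this graph so the Saver below can restore it as well.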
tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')
Then continue running:
saver = tf.train.Saver()
with tf.Session() as sess:
  ckpt = 'averaged.ckpt-0'  # checkpoint path
  print(ckpt)
  saver.restore(sess, ckpt)

  # Sample sentence to visualize
  # input_sentence = "It is in this spirit that a majority of American governments have passed new laws since 2009 making the registration or voting process more difficult."
  input_sentence = "The Law will never be perfect, but its application should be just - this is what we are missing, in my opinion."
  output_string, inp_text, out_text, att_mats = visualizer.get_vis_data_from_string(sess, input_sentence)
  print(output_string)

  call_html()
  attention.show(inp_text, out_text, *att_mats)
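To sanity-check what get_vis_data_from_string returned before interpreting the heatmaps, here is a minimal sketch that uses only the variables defined above:

# att_mats is the tuple (enc_atts, dec_atts, encdec_atts); each element is a
# list with one numpy array per layer.
enc_atts, dec_atts, encdec_atts = att_mats
print(len(enc_atts), enc_atts[0].shape)        # num_layers, (1, num_heads, inp_len, inp_len)
print(len(dec_atts), dec_atts[0].shape)        # num_layers, (1, num_heads, out_len, out_len)
print(len(encdec_atts), encdec_atts[0].shape)  # num_layers, (1, num_heads, out_len, inp_len)
print(inp_text)   # tokenized source sentence
print(out_text)   # tokenized translation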
Visualization result:
[Interactive attention heatmap rendered by attention.show]