Attention Mechanisms: Multi-Head Attention Exercises
1. Visualize attention weights of multiple heads in this experiment.
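The visualization code below assumes that `attention`, `num_heads`, and `d2l` are already defined, for example by the book's toy multi-head attention example. One possible setup (a minimal sketch; the hyperparameter and input values follow that toy example and are only illustrative):

import torch
from d2l import torch as d2l

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
# A forward pass populates attention.attention.attention_weights
attention(X, Y, Y, valid_lens)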
from matplotlib import pyplot as plt

# Shape of `out`: (batch_size * num_heads, no. of queries, no. of key-value pairs)
out = attention.attention.attention_weights.detach()
print(out.shape)

# Shape of `out1`: (batch_size, num_heads, no. of queries, no. of key-value pairs)
out1 = out.reshape(-1, num_heads, out.shape[1], out.shape[2])
print(out1.shape)

d2l.show_heatmaps(out1, xlabel='Keys', ylabel='Queries')
print(out1)
2. Suppose that we have a trained model based on multi-head attention and we want to prune the least important attention heads to increase the prediction speed. How can we design experiments to measure the importance of an attention head?
The idea is to attach a learnable weight to each head, train once, and read the learned weights off as importance scores; the modified MultiHeadAttention below implements this.

1. The output of DotProductAttention has shape [batch_size * num_heads, no. of queries, num_hiddens / num_heads]; reshape it to [batch_size, num_heads, no. of queries, num_hiddens / num_heads].
2. Define a head-weight parameter of shape [1, num_heads, 1, 1].
3. Multiply the DotProductAttention output by the head weights; the shape stays [batch_size, num_heads, no. of queries, num_hiddens / num_heads].
4. Concatenate the num_heads heads back into one tensor.
5. Run the model once (i.e., train it), then read out the learned head weights; heads with small weights contribute little and are the candidates for pruning.
import torch
from torch import nn
from d2l import torch as d2l


class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        # Learnable weight per head, broadcast over batch, queries and channels
        self.head_attention = nn.Parameter(torch.rand(1, num_heads, 1, 1))

    def forward(self, queries, keys, values, valid_lens):
        # Shape of `queries`, `keys`, or `values`:
        # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
        # Shape of `valid_lens`:
        # (`batch_size`,) or (`batch_size`, no. of queries)
        # After transposing, shape of output `queries`, `keys`, or `values`:
        # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
        # `num_hiddens` / `num_heads`)
        queries = d2l.transpose_qkv(self.W_q(queries), self.num_heads)
        keys = d2l.transpose_qkv(self.W_k(keys), self.num_heads)
        values = d2l.transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for
            # `num_heads` times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
        # `num_hiddens` / `num_heads`)
        output = self.attention(queries, keys, values, valid_lens)

        # ******************** different from the original implementation
        # Reshape to (`batch_size`, `num_heads`, no. of queries,
        # `num_hiddens` / `num_heads`)
        output = output.reshape(-1, self.num_heads,
                                output.shape[1], output.shape[2])
        # Normalize the head weights across heads; shape (1, `num_heads`, 1, 1)
        self.head_attention_tran = nn.Softmax(dim=1)(self.head_attention)
        # Scale each head's output by its weight (broadcasting)
        output = output * self.head_attention_tran
        # Concatenate the heads:
        # (`batch_size`, no. of queries, `num_heads`, `num_hiddens` / `num_heads`)
        # -> (`batch_size`, no. of queries, `num_hiddens`)
        output = output.permute(0, 2, 1, 3)
        output_concat = output.reshape(output.shape[0], output.shape[1], -1)
        # *****************************************************************
        return self.W_o(output_concat)
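To check the modified module and read out the head weights, a single forward pass on toy inputs suffices (a minimal sketch; the shapes and values below are illustrative, not taken from the exercise). In practice the weights only become meaningful after training on the real task, after which they can be compared across heads:

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
print(attention(X, Y, Y, valid_lens).shape)  # torch.Size([2, 4, 100])

# Softmax-normalized head weights sum to 1 across heads; after training,
# the heads with the smallest weights are the candidates for pruning.
print(attention.head_attention_tran.squeeze())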
