if is_cross_attention and past_key_value is not None:
    # reuse k, v from the cross-attention cache
    key_layer = past_key_value[0]
    value_layer = past_key_value[1]
    attention_mask = encoder_attention_mask
elif is_cross_attention:
    key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
    value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
    attention_mask = encoder_attention_mask
elif past_key_value is not None:
    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))
    key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
    value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
    key_layer = self.transpose_for_scores(self.key(hidden_states))      # b x s x 768 -> b x 12 x s x 64
    value_layer = self.transpose_for_scores(self.value(hidden_states))  # b x s x 768 -> b x 12 x s x 64
query_layer = self.transpose_for_scores(mixed_query_layer) # b x s x 768 -> b x 12 x s x 64
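# For reference, transpose_for_scores is the helper (defined elsewhere in this class)
# that produces the per-head layout annotated above. A minimal sketch, assuming the
# BERT-base configuration of 12 heads x 64 dims:
#
#   def transpose_for_scores(self, x):                  # x: b x s x 768
#       new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
#       x = x.view(new_shape)                            # b x s x 12 x 64
#       return x.permute(0, 2, 1, 3)                     # b x 12 x s x 64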
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))  # QK^T: b x 12 x s x s

attention_scores = attention_scores / math.sqrt(self.attention_head_size)  # QK^T / sqrt(64)
if attention_mask is not None:
    # Apply the attention mask (precomputed for all layers in BertModel's forward() function).
    # The mask is additive: padded positions hold large negative values, so softmax gives them ~0 weight.
    attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)  # dropout

# Mask heads if we want to
if head_mask is not None:
    attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)  # attention-weighted sum of V: b x 12 x s x 64
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()  # b x 12 x s x 64 -> b x s x 12 x 64
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)  # b x s x 12 x 64 -> b x s x 768
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
    outputs = outputs + (past_key_value,)
return outputs
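To see the whole computation end to end, here is a minimal, self-contained sketch of the same math outside the class. The helper name multi_head_self_attention and the random projection weights are hypothetical; it assumes BERT-base sizes (12 heads x 64 dims) and omits bias, attention_mask, head_mask, dropout, and past_key_value caching.

import math
import torch
import torch.nn.functional as F

def multi_head_self_attention(hidden_states, w_q, w_k, w_v, num_heads=12, head_size=64):
    """Split heads, compute QK^T / sqrt(d), softmax, then take the weighted sum of V."""
    b, s, _ = hidden_states.shape

    def split_heads(x):                                           # b x s x 768 -> b x 12 x s x 64
        return x.view(b, s, num_heads, head_size).permute(0, 2, 1, 3)

    q = split_heads(hidden_states @ w_q)
    k = split_heads(hidden_states @ w_k)
    v = split_heads(hidden_states @ w_v)

    scores = q @ k.transpose(-1, -2) / math.sqrt(head_size)      # b x 12 x s x s
    probs = F.softmax(scores, dim=-1)
    context = probs @ v                                           # b x 12 x s x 64
    return context.permute(0, 2, 1, 3).reshape(b, s, num_heads * head_size)  # b x s x 768

# Tiny shape check with random weights:
x = torch.randn(2, 5, 768)
w = lambda: torch.randn(768, 768) / 768 ** 0.5
out = multi_head_self_attention(x, w(), w(), w())
print(out.shape)  # torch.Size([2, 5, 768])

The real module additionally adds the precomputed attention_mask to the scores before the softmax, applies dropout and the optional head_mask to attention_probs, and handles the cross-attention / past_key_value branches shown above.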