From 09cfd122353347da7a62eb4f5af75d83b955684f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 10 Oct 2019 10:15:27 +0200 Subject: [PATCH] remove and do the branching in --- transformers/modeling_bert.py | 68 +++-------------------------------- 1 file changed, 5 insertions(+), 63 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index db8847f39e..94791571cd 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -174,67 +174,6 @@ class BertEmbeddings(nn.Module): return embeddings -class BertGeneralAttention(nn.Module): - def __init__(self, config): - super(BertGeneralAttention, self).__init__() - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads)) - self.output_attentions = config.output_attentions - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, query, key, value, attention_mask=None, head_mask=None): - mixed_query_layer = self.query(query) - mixed_key_layer = self.key(key) - mixed_value_layer = self.value(value) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.Softmax(dim=-1)(attention_scores) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - - outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) - return outputs - - class BertSelfAttention(nn.Module): def __init__(self, config): super(BertSelfAttention, self).__init__() @@ -259,10 +198,13 @@ class BertSelfAttention(nn.Module): x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, attention_mask=None, head_mask=None): - mixed_query_layer = self.query(hidden_states) + def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None): mixed_key_layer = self.key(hidden_states) mixed_value_layer = self.value(hidden_states) + if encoder_hidden_states: # if encoder-decoder attention + mixed_query_layer = self.query(encoder_hidden_states) + else: + mixed_query_layer = self.query(hidden_states) query_layer = self.transpose_for_scores(mixed_query_layer) key_layer = self.transpose_for_scores(mixed_key_layer)