create encoder attention mask from shape of hidden states

We currently create encoder attention masks (when they're not provided)
based on the shape of the inputs to the encoder. This is obviously
wrong; sequences can be of different lengths. We now create the encoder
attention mask based on the batch_size and sequence_length of the
encoder hidden states.
This commit is contained in:
Rémi Louf
2019-12-09 11:13:09 +01:00
parent 0cb163865a
commit 3520be7824

View File

@@ -691,16 +691,18 @@ class BertModel(BertPreTrainedModel):
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder:
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(input_shape, device=device)
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
if encoder_attention_mask.dim() == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
elif encoder_attention_mask.dim() == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
else:
raise ValueError("Wrong shape for input_ids (shape {}) or encoder_attention_mask (shape {})".format(input_shape,
raise ValueError("Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(encoder_hidden_shape,
encoder_attention_mask.shape))
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility