create encoder attention mask from shape of hidden states
We currently create encoder attention masks (when they're not provided) based on the shape of the inputs to the encoder. This is obviously wrong; sequences can be of different lengths. We now create the encoder attention mask based on the batch_size and sequence_length of the encoder hidden states.
This commit is contained in:
@@ -691,16 +691,18 @@ class BertModel(BertPreTrainedModel):
|
||||
|
||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder:
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(input_shape, device=device)
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
|
||||
if encoder_attention_mask.dim() == 3:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
||||
elif encoder_attention_mask.dim() == 2:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
else:
|
||||
raise ValueError("Wrong shape for input_ids (shape {}) or encoder_attention_mask (shape {})".format(input_shape,
|
||||
raise ValueError("Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(encoder_hidden_shape,
|
||||
encoder_attention_mask.shape))
|
||||
|
||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
|
||||
Reference in New Issue
Block a user