From 3520be7824ad11ebc05a393fd90ecfdd4203cfdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 9 Dec 2019 11:13:09 +0100 Subject: [PATCH] create encoder attention mask from shape of hidden states We currently create encoder attention masks (when they're not provided) based on the shape of the decoder's own input_ids. This is wrong: the encoder and decoder sequences can be of different lengths. We now create the encoder attention mask based on the batch_size and sequence_length of the encoder hidden states. --- transformers/modeling_bert.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 1ee3e3f097..8295cf4664 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -691,17 +691,19 @@ class BertModel(BertPreTrainedModel): # If a 2D ou 3D attention mask is provided for the cross-attention # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder: + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(input_shape, device=device) + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) if encoder_attention_mask.dim() == 3: encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] elif encoder_attention_mask.dim() == 2: encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] else: - raise ValueError("Wrong shape for input_ids (shape {}) or encoder_attention_mask (shape {})".format(input_shape, - encoder_attention_mask.shape)) + raise ValueError("Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(encoder_hidden_shape, + encoder_attention_mask.shape))
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0