From 28d0ba35d73d5b8b31fdadd72686a3ac078a6143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 8 Nov 2019 11:22:19 +0100 Subject: [PATCH 1/2] only init encoder_attention_mask if stack is decoder We currently initialize `encoder_attention_mask` when it is `None`, whether the stack is that of an encoder or a decoder. Since this may lead to bugs that are difficult to tracks down, I added a condition that assesses whether the current stack is a decoder. --- transformers/modeling_bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 7c2c6f4602..6bd5ab6a2e 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -656,7 +656,7 @@ class BertModel(BertPreTrainedModel): if attention_mask is None: attention_mask = torch.ones(input_shape, device=device) - if encoder_attention_mask is None: + if self.config.is_decoder and encoder_attention_mask is None: encoder_attention_mask = torch.ones(input_shape, device=device) if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) From cd286c2145221f3d1372aef103d0bc3ed03879da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 8 Nov 2019 11:31:16 +0100 Subject: [PATCH 2/2] add condition around mask transformation --- transformers/modeling_bert.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 6bd5ab6a2e..893ec51015 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -656,8 +656,6 @@ class BertModel(BertPreTrainedModel): if attention_mask is None: attention_mask = torch.ones(input_shape, device=device) - if self.config.is_decoder and encoder_attention_mask is None: - encoder_attention_mask = torch.ones(input_shape, device=device) if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) @@ -688,13 +686,19 @@ class BertModel(BertPreTrainedModel): # If a 2D ou 3D attention mask is provided for the cross-attention # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] - if encoder_attention_mask.dim() == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if encoder_attention_mask.dim() == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + if self.config.is_decoder: + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(input_shape, device=device) - encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head