Merge pull request #1770 from huggingface/initi-encoder-mask
Only init encoder_attention_mask if stack is decoder
This commit is contained in:
@@ -660,8 +660,6 @@ class BertModel(BertPreTrainedModel):
|
|||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(input_shape, device=device)
|
attention_mask = torch.ones(input_shape, device=device)
|
||||||
if encoder_attention_mask is None:
|
|
||||||
encoder_attention_mask = torch.ones(input_shape, device=device)
|
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
|
||||||
@@ -692,13 +690,19 @@ class BertModel(BertPreTrainedModel):
|
|||||||
|
|
||||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||||
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
||||||
if encoder_attention_mask.dim() == 3:
|
if self.config.is_decoder:
|
||||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
if encoder_attention_mask is None:
|
||||||
if encoder_attention_mask.dim() == 2:
|
encoder_attention_mask = torch.ones(input_shape, device=device)
|
||||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
|
||||||
|
|
||||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
if encoder_attention_mask.dim() == 3:
|
||||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
||||||
|
if encoder_attention_mask.dim() == 2:
|
||||||
|
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||||
|
|
||||||
|
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||||
|
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
|||||||
Reference in New Issue
Block a user