only init encoder_attention_mask if stack is decoder
We currently initialize `encoder_attention_mask` when it is `None`, regardless of whether the stack is an encoder or a decoder. Since this may lead to bugs that are difficult to track down, I added a condition that checks whether the current stack is a decoder.
This commit is contained in:
@@ -656,7 +656,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(input_shape, device=device)
|
attention_mask = torch.ones(input_shape, device=device)
|
||||||
if encoder_attention_mask is None:
|
if self.config.is_decoder and encoder_attention_mask is None:
|
||||||
encoder_attention_mask = torch.ones(input_shape, device=device)
|
encoder_attention_mask = torch.ones(input_shape, device=device)
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
|||||||
Reference in New Issue
Block a user