BERT decoder: Fix causal mask dtype.
PyTorch < 1.3 requires multiplication operands to be of the same type. This was violated when using default attention mask (i.e., attention_mask=None in arguments) given BERT in the decoder mode. In particular, this was breaking Model2Model and made tutorial from the quickstart failing.
This commit is contained in:
committed by
Lysandre Debut
parent
bed38d3afe
commit
ee5de0ba44
@@ -733,8 +733,8 @@ class BertModel(BertPreTrainedModel):
|
||||
seq_ids = torch.arange(seq_length, device=device)
|
||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||
causal_mask = causal_mask.to(
|
||||
torch.long
|
||||
) # not converting to long will cause errors with pytorch version < 1.3
|
||||
attention_mask.dtype
|
||||
) # causal and attention masks must have same type with pytorch version < 1.3
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||
else:
|
||||
extended_attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
@@ -438,6 +438,34 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
self.model_tester.create_and_check_bert_model_as_decoder(*config_and_inputs)
|
||||
|
||||
def test_bert_model_as_decoder_with_default_input_mask(self):
|
||||
# This regression test was failing with PyTorch < 1.3
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
|
||||
input_mask = None
|
||||
|
||||
self.model_tester.create_and_check_bert_model_as_decoder(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user