BERT decoder: Fix causal mask dtype.
PyTorch < 1.3 requires multiplication operands to be of the same type. This was violated when using default attention mask (i.e., attention_mask=None in arguments) given BERT in the decoder mode. In particular, this was breaking Model2Model and made tutorial from the quickstart failing.
This commit is contained in:
committed by
Lysandre Debut
parent
bed38d3afe
commit
ee5de0ba44
@@ -733,8 +733,8 @@ class BertModel(BertPreTrainedModel):
|
|||||||
seq_ids = torch.arange(seq_length, device=device)
|
seq_ids = torch.arange(seq_length, device=device)
|
||||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||||
causal_mask = causal_mask.to(
|
causal_mask = causal_mask.to(
|
||||||
torch.long
|
attention_mask.dtype
|
||||||
) # not converting to long will cause errors with pytorch version < 1.3
|
) # causal and attention masks must have same type with pytorch version < 1.3
|
||||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||||
else:
|
else:
|
||||||
extended_attention_mask = attention_mask[:, None, None, :]
|
extended_attention_mask = attention_mask[:, None, None, :]
|
||||||
|
|||||||
@@ -438,6 +438,34 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
self.model_tester.create_and_check_bert_model_as_decoder(*config_and_inputs)
|
self.model_tester.create_and_check_bert_model_as_decoder(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_bert_model_as_decoder_with_default_input_mask(self):
|
||||||
|
# This regression test was failing with PyTorch < 1.3
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
) = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
|
|
||||||
|
input_mask = None
|
||||||
|
|
||||||
|
self.model_tester.create_and_check_bert_model_as_decoder(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
def test_for_masked_lm(self):
|
def test_for_masked_lm(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
|
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user