From ee5de0ba449d638da704e1c03ffcc20a930f5589 Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Thu, 6 Feb 2020 11:43:31 +0000 Subject: [PATCH] BERT decoder: Fix causal mask dtype. PyTorch < 1.3 requires multiplication operands to be of the same type. This was violated when using the default attention mask (i.e., attention_mask=None in the arguments) with BERT in decoder mode. In particular, this broke Model2Model and made the quickstart tutorial fail. --- src/transformers/modeling_bert.py | 4 ++-- tests/test_modeling_bert.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 5c032f05e5..6cfbe3d00a 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -733,8 +733,8 @@ class BertModel(BertPreTrainedModel): seq_ids = torch.arange(seq_length, device=device) causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] causal_mask = causal_mask.to( - torch.long - ) # not converting to long will cause errors with pytorch version < 1.3 + attention_mask.dtype + ) # causal and attention masks must have same type with pytorch version < 1.3 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] else: extended_attention_mask = attention_mask[:, None, None, :] diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index 946246ea2e..20b53f6fad 100644 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -438,6 +438,34 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_bert_model_as_decoder(*config_and_inputs) + def test_bert_model_as_decoder_with_default_input_mask(self): + # This regression test was failing with PyTorch < 1.3 + ( + config, + input_ids, + token_type_ids, + input_mask, +
sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) = self.model_tester.prepare_config_and_inputs_for_decoder() + + input_mask = None + + self.model_tester.create_and_check_bert_model_as_decoder( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)