From ed37f9fa4f59e93549cd5306c6d98ee3940cf1df Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 6 Mar 2020 16:06:36 -0500 Subject: [PATCH] [Bart] _prepare_decoder_inputs should use large negative (#3158) --- src/transformers/modeling_bart.py | 12 +++++------ tests/test_modeling_bart.py | 33 +++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 204796df4d..4000e159f9 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -65,7 +65,7 @@ BART_INPUTS_DOCSTRING = r""" If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify. See diagram 1 in the paper for more info on the default strategy """ -LARGE_NEGATIVE = -1e4 +LARGE_NEGATIVE = -1e8 def _prepare_bart_decoder_inputs( @@ -144,18 +144,18 @@ def _check_shapes(shape_1, shape2): raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2)) -def _combine_masks(key_padding_mask, attn_mask, targ_size): +def _combine_masks(key_padding_mask, causal_lm_mask, targ_size): # targ_size = (bsz, tgt_len, src_len) a = torch.zeros(targ_size) b = torch.zeros(targ_size) if key_padding_mask is not None: # (bsz, tgt_len) -> targ_size _check_shapes(key_padding_mask.shape, targ_size[:2]) reshaped = key_padding_mask.unsqueeze(2).expand(*targ_size) - a[reshaped] = 1e-8 + a[reshaped] = LARGE_NEGATIVE - if attn_mask is not None: # (tgt_len, src_len) -> targ_size - _check_shapes(attn_mask.shape, targ_size[-2:]) - b = attn_mask.unsqueeze(0).expand(*targ_size) + if causal_lm_mask is not None: # (tgt_len, src_len) -> targ_size + _check_shapes(causal_lm_mask.shape, targ_size[-2:]) + b = causal_lm_mask.unsqueeze(0).expand(*targ_size) return (a + b).unsqueeze(1).clamp(LARGE_NEGATIVE,) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 374c7b840f..b1d878e6c5 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -37,6 +37,7 @@ if is_torch_available(): BART_PRETRAINED_MODEL_ARCHIVE_MAP, shift_tokens_right, _prepare_bart_decoder_inputs, + LARGE_NEGATIVE, ) from transformers.tokenization_bart import BartTokenizer @@ -303,6 +304,38 @@ class BartHeadTests(unittest.TestCase): lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half() lm_model.generate(input_ids, attention_mask) + def test_prepare_bart_decoder_inputs(self): + config, *_ = self._get_config_and_data(output_past=False) + input_ids = _long_tensor(([4, 4, 2])) # only used for .device if decoder_input_ids is passed + decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]]) + ignore = LARGE_NEGATIVE + decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids) + expected_mask = torch.tensor( + [ + [0, ignore, ignore], + [0, 0, ignore], + [ignore, ignore, ignore], # never attend to the final token, because its pad + ] + ).to(input_ids.device) + self.assertEqual(decoder_attn_mask.size(), (1, 1, 3, 3)) + self.assertTrue(torch.eq(expected_mask, decoder_attn_mask).all()) + + # Test no causal mask + config, *_ = self._get_config_and_data(output_past=True) + expected_just_padding_mask = torch.tensor( + [[0, 0, 0], [0, 0, 0], [ignore, ignore, ignore]] # never attend to the final token, because its pad + ).to(input_ids.device) + _, decoder_attn_mask_no_causal_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids) + self.assertEqual(decoder_attn_mask_no_causal_mask.size(), (1, 1, 3, 3)) + self.assertTrue(torch.eq(expected_just_padding_mask, decoder_attn_mask_no_causal_mask).all()) + + decoder_input_ids = _long_tensor([[0, 26388, 4133, 2]]) + # Attend to everything if no pad tokens and no causal mask + _, decoder_attn_mask_no_padding_no_causal_mask = _prepare_bart_decoder_inputs( + config, input_ids, decoder_input_ids + ) + self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all()) + def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""