[Bart] _prepare_decoder_inputs should use large negative (#3158)

This commit is contained in:
Sam Shleifer
2020-03-06 16:06:36 -05:00
committed by GitHub
parent 0416d437fb
commit ed37f9fa4f
2 changed files with 39 additions and 6 deletions

View File

@@ -37,6 +37,7 @@ if is_torch_available():
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
shift_tokens_right,
_prepare_bart_decoder_inputs,
LARGE_NEGATIVE,
)
from transformers.tokenization_bart import BartTokenizer
@@ -303,6 +304,38 @@ class BartHeadTests(unittest.TestCase):
lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
lm_model.generate(input_ids, attention_mask)
def test_prepare_bart_decoder_inputs(self):
config, *_ = self._get_config_and_data(output_past=False)
input_ids = _long_tensor(([4, 4, 2])) # only used for .device if decoder_input_ids is passed
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
ignore = LARGE_NEGATIVE
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
expected_mask = torch.tensor(
[
[0, ignore, ignore],
[0, 0, ignore],
[ignore, ignore, ignore], # never attend to the final token, because its pad
]
).to(input_ids.device)
self.assertEqual(decoder_attn_mask.size(), (1, 1, 3, 3))
self.assertTrue(torch.eq(expected_mask, decoder_attn_mask).all())
# Test no causal mask
config, *_ = self._get_config_and_data(output_past=True)
expected_just_padding_mask = torch.tensor(
[[0, 0, 0], [0, 0, 0], [ignore, ignore, ignore]] # never attend to the final token, because its pad
).to(input_ids.device)
_, decoder_attn_mask_no_causal_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
self.assertEqual(decoder_attn_mask_no_causal_mask.size(), (1, 1, 3, 3))
self.assertTrue(torch.eq(expected_just_padding_mask, decoder_attn_mask_no_causal_mask).all())
decoder_input_ids = _long_tensor([[0, 26388, 4133, 2]])
# Attend to everything if no pad tokens and no causal mask
_, decoder_attn_mask_no_padding_no_causal_mask = _prepare_bart_decoder_inputs(
config, input_ids, decoder_input_ids
)
self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""