[Bart] _prepare_decoder_inputs should use large negative (#3158)
This commit is contained in:
@@ -37,6 +37,7 @@ if is_torch_available():
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
shift_tokens_right,
|
||||
_prepare_bart_decoder_inputs,
|
||||
LARGE_NEGATIVE,
|
||||
)
|
||||
from transformers.tokenization_bart import BartTokenizer
|
||||
|
||||
@@ -303,6 +304,38 @@ class BartHeadTests(unittest.TestCase):
|
||||
lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
|
||||
lm_model.generate(input_ids, attention_mask)
|
||||
|
||||
def test_prepare_bart_decoder_inputs(self):
|
||||
config, *_ = self._get_config_and_data(output_past=False)
|
||||
input_ids = _long_tensor(([4, 4, 2])) # only used for .device if decoder_input_ids is passed
|
||||
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
|
||||
ignore = LARGE_NEGATIVE
|
||||
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
|
||||
expected_mask = torch.tensor(
|
||||
[
|
||||
[0, ignore, ignore],
|
||||
[0, 0, ignore],
|
||||
[ignore, ignore, ignore], # never attend to the final token, because its pad
|
||||
]
|
||||
).to(input_ids.device)
|
||||
self.assertEqual(decoder_attn_mask.size(), (1, 1, 3, 3))
|
||||
self.assertTrue(torch.eq(expected_mask, decoder_attn_mask).all())
|
||||
|
||||
# Test no causal mask
|
||||
config, *_ = self._get_config_and_data(output_past=True)
|
||||
expected_just_padding_mask = torch.tensor(
|
||||
[[0, 0, 0], [0, 0, 0], [ignore, ignore, ignore]] # never attend to the final token, because its pad
|
||||
).to(input_ids.device)
|
||||
_, decoder_attn_mask_no_causal_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
|
||||
self.assertEqual(decoder_attn_mask_no_causal_mask.size(), (1, 1, 3, 3))
|
||||
self.assertTrue(torch.eq(expected_just_padding_mask, decoder_attn_mask_no_causal_mask).all())
|
||||
|
||||
decoder_input_ids = _long_tensor([[0, 26388, 4133, 2]])
|
||||
# Attend to everything if no pad tokens and no causal mask
|
||||
_, decoder_attn_mask_no_padding_no_causal_mask = _prepare_bart_decoder_inputs(
|
||||
config, input_ids, decoder_input_ids
|
||||
)
|
||||
self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
|
||||
Reference in New Issue
Block a user