[Bart] _prepare_decoder_inputs should use large negative (#3158)
This commit is contained in:
@@ -65,7 +65,7 @@ BART_INPUTS_DOCSTRING = r"""
|
|||||||
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
|
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
|
||||||
See diagram 1 in the paper for more info on the default strategy
|
See diagram 1 in the paper for more info on the default strategy
|
||||||
"""
|
"""
|
||||||
LARGE_NEGATIVE = -1e4
|
LARGE_NEGATIVE = -1e8
|
||||||
|
|
||||||
|
|
||||||
def _prepare_bart_decoder_inputs(
|
def _prepare_bart_decoder_inputs(
|
||||||
@@ -144,18 +144,18 @@ def _check_shapes(shape_1, shape2):
|
|||||||
raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
|
raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
|
||||||
|
|
||||||
|
|
||||||
def _combine_masks(key_padding_mask, attn_mask, targ_size):
|
def _combine_masks(key_padding_mask, causal_lm_mask, targ_size):
|
||||||
# targ_size = (bsz, tgt_len, src_len)
|
# targ_size = (bsz, tgt_len, src_len)
|
||||||
a = torch.zeros(targ_size)
|
a = torch.zeros(targ_size)
|
||||||
b = torch.zeros(targ_size)
|
b = torch.zeros(targ_size)
|
||||||
if key_padding_mask is not None: # (bsz, tgt_len) -> targ_size
|
if key_padding_mask is not None: # (bsz, tgt_len) -> targ_size
|
||||||
_check_shapes(key_padding_mask.shape, targ_size[:2])
|
_check_shapes(key_padding_mask.shape, targ_size[:2])
|
||||||
reshaped = key_padding_mask.unsqueeze(2).expand(*targ_size)
|
reshaped = key_padding_mask.unsqueeze(2).expand(*targ_size)
|
||||||
a[reshaped] = 1e-8
|
a[reshaped] = LARGE_NEGATIVE
|
||||||
|
|
||||||
if attn_mask is not None: # (tgt_len, src_len) -> targ_size
|
if causal_lm_mask is not None: # (tgt_len, src_len) -> targ_size
|
||||||
_check_shapes(attn_mask.shape, targ_size[-2:])
|
_check_shapes(causal_lm_mask.shape, targ_size[-2:])
|
||||||
b = attn_mask.unsqueeze(0).expand(*targ_size)
|
b = causal_lm_mask.unsqueeze(0).expand(*targ_size)
|
||||||
return (a + b).unsqueeze(1).clamp(LARGE_NEGATIVE,)
|
return (a + b).unsqueeze(1).clamp(LARGE_NEGATIVE,)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ if is_torch_available():
|
|||||||
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
shift_tokens_right,
|
shift_tokens_right,
|
||||||
_prepare_bart_decoder_inputs,
|
_prepare_bart_decoder_inputs,
|
||||||
|
LARGE_NEGATIVE,
|
||||||
)
|
)
|
||||||
from transformers.tokenization_bart import BartTokenizer
|
from transformers.tokenization_bart import BartTokenizer
|
||||||
|
|
||||||
@@ -303,6 +304,38 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
|
lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
|
||||||
lm_model.generate(input_ids, attention_mask)
|
lm_model.generate(input_ids, attention_mask)
|
||||||
|
|
||||||
|
def test_prepare_bart_decoder_inputs(self):
|
||||||
|
config, *_ = self._get_config_and_data(output_past=False)
|
||||||
|
input_ids = _long_tensor(([4, 4, 2])) # only used for .device if decoder_input_ids is passed
|
||||||
|
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
|
||||||
|
ignore = LARGE_NEGATIVE
|
||||||
|
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
|
||||||
|
expected_mask = torch.tensor(
|
||||||
|
[
|
||||||
|
[0, ignore, ignore],
|
||||||
|
[0, 0, ignore],
|
||||||
|
[ignore, ignore, ignore], # never attend to the final token, because its pad
|
||||||
|
]
|
||||||
|
).to(input_ids.device)
|
||||||
|
self.assertEqual(decoder_attn_mask.size(), (1, 1, 3, 3))
|
||||||
|
self.assertTrue(torch.eq(expected_mask, decoder_attn_mask).all())
|
||||||
|
|
||||||
|
# Test no causal mask
|
||||||
|
config, *_ = self._get_config_and_data(output_past=True)
|
||||||
|
expected_just_padding_mask = torch.tensor(
|
||||||
|
[[0, 0, 0], [0, 0, 0], [ignore, ignore, ignore]] # never attend to the final token, because its pad
|
||||||
|
).to(input_ids.device)
|
||||||
|
_, decoder_attn_mask_no_causal_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
|
||||||
|
self.assertEqual(decoder_attn_mask_no_causal_mask.size(), (1, 1, 3, 3))
|
||||||
|
self.assertTrue(torch.eq(expected_just_padding_mask, decoder_attn_mask_no_causal_mask).all())
|
||||||
|
|
||||||
|
decoder_input_ids = _long_tensor([[0, 26388, 4133, 2]])
|
||||||
|
# Attend to everything if no pad tokens and no causal mask
|
||||||
|
_, decoder_attn_mask_no_padding_no_causal_mask = _prepare_bart_decoder_inputs(
|
||||||
|
config, input_ids, decoder_input_ids
|
||||||
|
)
|
||||||
|
self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
|
||||||
|
|
||||||
|
|
||||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||||
|
|||||||
Reference in New Issue
Block a user