[Bart] _prepare_decoder_inputs should use large negative (#3158)

This commit is contained in:
Sam Shleifer
2020-03-06 16:06:36 -05:00
committed by GitHub
parent 0416d437fb
commit ed37f9fa4f
2 changed files with 39 additions and 6 deletions

View File

@@ -65,7 +65,7 @@ BART_INPUTS_DOCSTRING = r"""
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
See diagram 1 in the paper for more info on the default strategy
"""
# Value written into masked positions and used as the lower clamp bound when
# the decoder masks are combined in `_combine_masks`.
LARGE_NEGATIVE = -1e8
def _prepare_bart_decoder_inputs(
@@ -144,18 +144,18 @@ def _check_shapes(shape_1, shape2):
raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
def _combine_masks(key_padding_mask, causal_lm_mask, targ_size):
    """Merge an optional padding mask and an optional causal mask into one additive mask.

    Args:
        key_padding_mask: ``None`` or a tensor whose shape equals ``targ_size[:2]``.
            It is used as a mask index: selected positions are filled with
            ``LARGE_NEGATIVE``. (Presumably a boolean mask that is True at padding
            positions -- confirm against the caller.)
        causal_lm_mask: ``None`` or a tensor whose shape equals ``targ_size[-2:]``.
            It is broadcast over the leading dimension and added as-is.
        targ_size: target shape, ``(bsz, tgt_len, src_len)``.

    Returns:
        The sum of both masks with a singleton dimension inserted at index 1,
        clamped from below at ``LARGE_NEGATIVE``. When both inputs are ``None``
        the result is all zeros.

    Shape mismatches are reported by ``_check_shapes``.
    """
    # targ_size = (bsz, tgt_len, src_len)
    padding_term = torch.zeros(targ_size)
    causal_term = torch.zeros(targ_size)

    if key_padding_mask is not None:  # (bsz, tgt_len) -> targ_size
        _check_shapes(key_padding_mask.shape, targ_size[:2])
        expanded_padding = key_padding_mask.unsqueeze(2).expand(*targ_size)
        padding_term[expanded_padding] = LARGE_NEGATIVE

    if causal_lm_mask is not None:  # (tgt_len, src_len) -> targ_size
        _check_shapes(causal_lm_mask.shape, targ_size[-2:])
        causal_term = causal_lm_mask.unsqueeze(0).expand(*targ_size)

    # A position masked by both terms would sum below LARGE_NEGATIVE; the clamp
    # keeps every entry at or above that floor.
    return (padding_term + causal_term).unsqueeze(1).clamp(min=LARGE_NEGATIVE)