[Bart] _prepare_decoder_inputs should use large negative (#3158)
This commit is contained in:
@@ -65,7 +65,7 @@ BART_INPUTS_DOCSTRING = r"""
|
||||
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_bart_decoder_inputs` and modify it to your needs.
|
||||
See diagram 1 in the paper for more info on the default strategy.
|
||||
"""
|
||||
LARGE_NEGATIVE = -1e4
|
||||
LARGE_NEGATIVE = -1e8
|
||||
|
||||
|
||||
def _prepare_bart_decoder_inputs(
|
||||
@@ -144,18 +144,18 @@ def _check_shapes(shape_1, shape2):
|
||||
raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
|
||||
|
||||
|
||||
def _combine_masks(key_padding_mask, attn_mask, targ_size):
|
||||
def _combine_masks(key_padding_mask, causal_lm_mask, targ_size):
|
||||
# targ_size = (bsz, tgt_len, src_len)
|
||||
a = torch.zeros(targ_size)
|
||||
b = torch.zeros(targ_size)
|
||||
if key_padding_mask is not None: # (bsz, tgt_len) -> targ_size
|
||||
_check_shapes(key_padding_mask.shape, targ_size[:2])
|
||||
reshaped = key_padding_mask.unsqueeze(2).expand(*targ_size)
|
||||
a[reshaped] = 1e-8
|
||||
a[reshaped] = LARGE_NEGATIVE
|
||||
|
||||
if attn_mask is not None: # (tgt_len, src_len) -> targ_size
|
||||
_check_shapes(attn_mask.shape, targ_size[-2:])
|
||||
b = attn_mask.unsqueeze(0).expand(*targ_size)
|
||||
if causal_lm_mask is not None: # (tgt_len, src_len) -> targ_size
|
||||
_check_shapes(causal_lm_mask.shape, targ_size[-2:])
|
||||
b = causal_lm_mask.unsqueeze(0).expand(*targ_size)
|
||||
return (a + b).unsqueeze(1).clamp(LARGE_NEGATIVE,)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user