Not use -1e4 as attn mask (#17306)
* Use torch.finfo(self.dtype).min * for GPTNeoX * for Albert * For Splinter * Update src/transformers/models/data2vec/modeling_data2vec_audio.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * fix -inf used in Bart-like models * Fix a few remaining -inf * more fix * clean up * For CLIP * For FSMT * clean up * fix test * Add dtype argument and use it for LayoutLMv3 * update FlaxLongT5Attention Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -351,9 +351,10 @@ class FSMTHeadTests(unittest.TestCase):
|
||||
config, *_ = self._get_config_and_data()
|
||||
input_ids = _long_tensor(([4, 4, 2]))
|
||||
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
|
||||
ignore = float("-inf")
|
||||
causal_mask_dtype = torch.float32
|
||||
ignore = torch.finfo(causal_mask_dtype).min
|
||||
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
|
||||
config, input_ids, decoder_input_ids
|
||||
config, input_ids, decoder_input_ids, causal_mask_dtype=causal_mask_dtype
|
||||
)
|
||||
expected_causal_mask = torch.tensor(
|
||||
[[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]] # never attend to the final token, because its pad
|
||||
|
||||
Reference in New Issue
Block a user