From 913d03dc5e78b82c24be7a52c9ad06dd1022f1e2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 17 Nov 2023 10:15:00 +0000 Subject: [PATCH] Generate: fix flaky tests (#27543) --- src/transformers/generation/logits_process.py | 3 +- tests/generation/test_logits_process.py | 2 +- tests/generation/test_utils.py | 40 ++++++++----------- 3 files changed, 20 insertions(+), 25 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 3d1801b248..d1704ed020 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1301,8 +1301,9 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor): # set all nan values to 0.0 scores[scores != scores] = 0.0 - # set all inf values to max possible value + # set all +/-inf values to max/min possible value scores[scores == float("inf")] = torch.finfo(scores.dtype).max + scores[scores == float("-inf")] = torch.finfo(scores.dtype).min return scores diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index 15f5cf1e4f..9e5ccd16eb 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -692,7 +692,7 @@ class LogitsProcessorTest(unittest.TestCase): torch.allclose( scores, torch.tensor( - [[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]], + [[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, torch.finfo(scores.dtype).min]], device=torch_device, ), atol=1e-6, diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 1e76c88c71..729c7f8734 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -124,9 +124,14 @@ class GenerationTesterMixin: process_kwargs = { "min_length": input_length + 1 if max_length is None else max_length - 1, "bad_words_ids": [[1, 0]], - "no_repeat_ngram_size": 2, "repetition_penalty": 1.2, + "remove_invalid_values": True, } + # NoRepeatNGramLogitsProcessor + forced tokens may result in no valid continuations + if forced_bos_token_id is None and forced_eos_token_id is None: + process_kwargs["no_repeat_ngram_size"] = 2 + + # NOTE: the order of operations here should match `generate` for accurate testing logits_processor = LogitsProcessorList( ( [ @@ -154,12 +159,16 @@ class GenerationTesterMixin: if forced_eos_token_id is not None else [] ) - + [ - NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id), - NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]), - RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]), - ] + + [NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id)] + + ( + [NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"])] + if forced_bos_token_id is None and forced_eos_token_id is None + else [] + ) + + [RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"])] + + [InfNanRemoveLogitsProcessor()] # prevent flaky generation test failures ) + return process_kwargs, logits_processor @staticmethod @@ -282,7 +291,6 @@ class GenerationTesterMixin: output_hidden_states=output_hidden_states, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, **logits_process_kwargs, **model_kwargs, ) @@ -340,7 +348,6 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, **logits_warper_kwargs, **process_kwargs, **model_kwargs, @@ -361,9 +368,6 @@ class GenerationTesterMixin: elif attention_mask is not None: attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0) - # prevent flaky generation test failures - logits_processor.append(InfNanRemoveLogitsProcessor()) - with torch.no_grad(): model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_sample = model.sample( @@ -405,7 +409,6 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, **beam_kwargs, **logits_process_kwargs, **model_kwargs, @@ -467,7 +470,6 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, **beam_kwargs, **logits_warper_kwargs, **model_kwargs, @@ -534,7 +536,6 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, **beam_kwargs, **logits_process_kwargs, **model_kwargs, @@ -596,7 +597,6 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, constraints=constraints, **beam_kwargs, **logits_process_kwargs, @@ -671,7 +671,6 @@ class GenerationTesterMixin: output_hidden_states=output_hidden_states, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, - remove_invalid_values=True, **logits_process_kwargs, **model_kwargs, **contrastive_search_kwargs, @@ -1284,13 +1283,8 @@ class GenerationTesterMixin: # check `generate()` and `constrained_beam_search()` are equal # Sample constraints - if not input_ids.dtype == torch.float32: - min_id = torch.min(input_ids) + 3 - max_id = torch.max(input_ids) - else: - # otherwise this throws an error for Speech2TextModel since its inputs are floating points - min_id = 3 - max_id = 100 + min_id = 3 + max_id = config.vocab_size force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] constraints = [