Generate: fix flaky tests (#27543)
This commit is contained in:
@@ -692,7 +692,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
torch.allclose(
|
||||
scores,
|
||||
torch.tensor(
|
||||
-[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
|
||||
+[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, torch.finfo(scores.dtype).min]],
|
||||
device=torch_device,
|
||||
),
|
||||
atol=1e-6,
|
||||
|
||||
@@ -124,9 +124,14 @@ class GenerationTesterMixin:
|
||||
process_kwargs = {
|
||||
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
||||
"bad_words_ids": [[1, 0]],
|
||||
-"no_repeat_ngram_size": 2,
|
||||
"repetition_penalty": 1.2,
|
||||
+"remove_invalid_values": True,
|
||||
}
|
||||
+# NoRepeatNGramLogitsProcessor + forced tokens may result in no valid continuations
|
||||
+if forced_bos_token_id is None and forced_eos_token_id is None:
|
||||
+process_kwargs["no_repeat_ngram_size"] = 2
|
||||
|
||||
+# NOTE: the order of operations here should match `generate` for accurate testing
|
||||
logits_processor = LogitsProcessorList(
|
||||
(
|
||||
[
|
||||
@@ -154,12 +159,16 @@ class GenerationTesterMixin:
|
||||
if forced_eos_token_id is not None
|
||||
else []
|
||||
)
|
||||
+ [
|
||||
NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
|
||||
NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
|
||||
RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]),
|
||||
]
|
||||
+ [NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id)]
|
||||
+ (
|
||||
[NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"])]
|
||||
if forced_bos_token_id is None and forced_eos_token_id is None
|
||||
else []
|
||||
)
|
||||
+ [RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"])]
|
||||
+ [InfNanRemoveLogitsProcessor()] # prevent flaky generation test failures
|
||||
)
|
||||
|
||||
return process_kwargs, logits_processor
|
||||
|
||||
@staticmethod
|
||||
@@ -282,7 +291,6 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
-remove_invalid_values=True,
|
||||
**logits_process_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
@@ -340,7 +348,6 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
-remove_invalid_values=True,
|
||||
**logits_warper_kwargs,
|
||||
**process_kwargs,
|
||||
**model_kwargs,
|
||||
@@ -361,9 +368,6 @@ class GenerationTesterMixin:
|
||||
elif attention_mask is not None:
|
||||
attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)
|
||||
|
||||
-# prevent flaky generation test failures
|
||||
-logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
|
||||
with torch.no_grad():
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_sample = model.sample(
|
||||
@@ -405,7 +409,6 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
-remove_invalid_values=True,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
**model_kwargs,
|
||||
@@ -467,7 +470,6 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
-remove_invalid_values=True,
|
||||
**beam_kwargs,
|
||||
**logits_warper_kwargs,
|
||||
**model_kwargs,
|
||||
@@ -534,7 +536,6 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
-remove_invalid_values=True,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
**model_kwargs,
|
||||
@@ -596,7 +597,6 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
-remove_invalid_values=True,
|
||||
constraints=constraints,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
@@ -671,7 +671,6 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
-remove_invalid_values=True,
|
||||
**logits_process_kwargs,
|
||||
**model_kwargs,
|
||||
**contrastive_search_kwargs,
|
||||
@@ -1284,13 +1283,8 @@ class GenerationTesterMixin:
|
||||
|
||||
# check `generate()` and `constrained_beam_search()` are equal
|
||||
# Sample constraints
|
||||
-if not input_ids.dtype == torch.float32:
|
||||
-min_id = torch.min(input_ids) + 3
|
||||
-max_id = torch.max(input_ids)
|
||||
-else:
|
||||
-# otherwise this throws an error for Speech2TextModel since its inputs are floating points
|
||||
-min_id = 3
|
||||
-max_id = 100
|
||||
+min_id = 3
|
||||
+max_id = config.vocab_size
|
||||
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||
constraints = [
|
||||
|
||||
Reference in New Issue
Block a user