Generate: fix flaky tests (#27543)
This commit is contained in:
@@ -1301,8 +1301,9 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor):
|
|||||||
# set all nan values to 0.0
|
# set all nan values to 0.0
|
||||||
scores[scores != scores] = 0.0
|
scores[scores != scores] = 0.0
|
||||||
|
|
||||||
# set all inf values to max possible value
|
# set all +/-inf values to max/min possible value
|
||||||
scores[scores == float("inf")] = torch.finfo(scores.dtype).max
|
scores[scores == float("inf")] = torch.finfo(scores.dtype).max
|
||||||
|
scores[scores == float("-inf")] = torch.finfo(scores.dtype).min
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|||||||
@@ -692,7 +692,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
torch.allclose(
|
torch.allclose(
|
||||||
scores,
|
scores,
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
|
[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, torch.finfo(scores.dtype).min]],
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
),
|
),
|
||||||
atol=1e-6,
|
atol=1e-6,
|
||||||
|
|||||||
@@ -124,9 +124,14 @@ class GenerationTesterMixin:
|
|||||||
process_kwargs = {
|
process_kwargs = {
|
||||||
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
||||||
"bad_words_ids": [[1, 0]],
|
"bad_words_ids": [[1, 0]],
|
||||||
"no_repeat_ngram_size": 2,
|
|
||||||
"repetition_penalty": 1.2,
|
"repetition_penalty": 1.2,
|
||||||
|
"remove_invalid_values": True,
|
||||||
}
|
}
|
||||||
|
# NoRepeatNGramLogitsProcessor + forced tokens may result in no valid continuations
|
||||||
|
if forced_bos_token_id is None and forced_eos_token_id is None:
|
||||||
|
process_kwargs["no_repeat_ngram_size"] = 2
|
||||||
|
|
||||||
|
# NOTE: the order of operations here should match `generate` for accurate testing
|
||||||
logits_processor = LogitsProcessorList(
|
logits_processor = LogitsProcessorList(
|
||||||
(
|
(
|
||||||
[
|
[
|
||||||
@@ -154,12 +159,16 @@ class GenerationTesterMixin:
|
|||||||
if forced_eos_token_id is not None
|
if forced_eos_token_id is not None
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
+ [
|
+ [NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id)]
|
||||||
NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
|
+ (
|
||||||
NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
|
[NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"])]
|
||||||
RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]),
|
if forced_bos_token_id is None and forced_eos_token_id is None
|
||||||
]
|
else []
|
||||||
|
)
|
||||||
|
+ [RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"])]
|
||||||
|
+ [InfNanRemoveLogitsProcessor()] # prevent flaky generation test failures
|
||||||
)
|
)
|
||||||
|
|
||||||
return process_kwargs, logits_processor
|
return process_kwargs, logits_processor
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -282,7 +291,6 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
remove_invalid_values=True,
|
|
||||||
**logits_process_kwargs,
|
**logits_process_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -340,7 +348,6 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
remove_invalid_values=True,
|
|
||||||
**logits_warper_kwargs,
|
**logits_warper_kwargs,
|
||||||
**process_kwargs,
|
**process_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -361,9 +368,6 @@ class GenerationTesterMixin:
|
|||||||
elif attention_mask is not None:
|
elif attention_mask is not None:
|
||||||
attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)
|
attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)
|
||||||
|
|
||||||
# prevent flaky generation test failures
|
|
||||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
output_sample = model.sample(
|
output_sample = model.sample(
|
||||||
@@ -405,7 +409,6 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
remove_invalid_values=True,
|
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_process_kwargs,
|
**logits_process_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -467,7 +470,6 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
remove_invalid_values=True,
|
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_warper_kwargs,
|
**logits_warper_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -534,7 +536,6 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
remove_invalid_values=True,
|
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_process_kwargs,
|
**logits_process_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -596,7 +597,6 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
remove_invalid_values=True,
|
|
||||||
constraints=constraints,
|
constraints=constraints,
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_process_kwargs,
|
**logits_process_kwargs,
|
||||||
@@ -671,7 +671,6 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
remove_invalid_values=True,
|
|
||||||
**logits_process_kwargs,
|
**logits_process_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
**contrastive_search_kwargs,
|
**contrastive_search_kwargs,
|
||||||
@@ -1284,13 +1283,8 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
# check `generate()` and `constrained_beam_search()` are equal
|
# check `generate()` and `constrained_beam_search()` are equal
|
||||||
# Sample constraints
|
# Sample constraints
|
||||||
if not input_ids.dtype == torch.float32:
|
min_id = 3
|
||||||
min_id = torch.min(input_ids) + 3
|
max_id = config.vocab_size
|
||||||
max_id = torch.max(input_ids)
|
|
||||||
else:
|
|
||||||
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
|
|
||||||
min_id = 3
|
|
||||||
max_id = 100
|
|
||||||
|
|
||||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||||
constraints = [
|
constraints = [
|
||||||
|
|||||||
Reference in New Issue
Block a user