[Generate] Add save mode logits processor to remove nans and infs if necessary (#10769)
* push * finish * finish * make fix copies * change name
This commit is contained in:
committed by
GitHub
parent
9f8fa4e973
commit
77bf3fe787
@@ -31,6 +31,7 @@ if is_torch_available():
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
@@ -436,3 +437,24 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores)
|
||||
self.assertFalse(torch.isinf(scores).any())
|
||||
|
||||
def test_remove_nan_inf_logits_processor(self):
|
||||
scores = torch.tensor(
|
||||
[[0.0, 0.7, 0.8, float("nan")], [0.1, float("inf"), 0.3, float("-inf")]], device=torch_device
|
||||
)
|
||||
input_ids = ids_tensor((2, 4), vocab_size=20)
|
||||
|
||||
logits_processor = InfNanRemoveLogitsProcessor()
|
||||
|
||||
scores = logits_processor(input_ids, scores)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
scores,
|
||||
torch.tensor(
|
||||
[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
|
||||
device=torch_device,
|
||||
),
|
||||
atol=1e-6,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -29,6 +29,7 @@ if is_torch_available():
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
@@ -229,6 +230,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
|
||||
@@ -284,6 +286,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**logits_warper_kwargs,
|
||||
**process_kwargs,
|
||||
)
|
||||
@@ -305,19 +308,23 @@ class GenerationTesterMixin:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
|
||||
|
||||
# prevent flaky generation test failures
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
|
||||
with torch.no_grad():
|
||||
output_sample = model.sample(
|
||||
input_ids_clone,
|
||||
attention_mask=attention_mask_clone,
|
||||
max_length=max_length,
|
||||
logits_processor=logits_processor,
|
||||
logits_warper=logits_warper,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
)
|
||||
with torch.no_grad():
|
||||
output_sample = model.sample(
|
||||
input_ids_clone,
|
||||
attention_mask=attention_mask_clone,
|
||||
max_length=max_length,
|
||||
logits_processor=logits_processor,
|
||||
logits_warper=logits_warper,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
)
|
||||
return output_sample, output_generate
|
||||
|
||||
def _beam_search_generate(
|
||||
@@ -344,6 +351,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
@@ -406,6 +414,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**beam_kwargs,
|
||||
**logits_warper_kwargs,
|
||||
)
|
||||
@@ -424,6 +433,10 @@ class GenerationTesterMixin:
|
||||
else:
|
||||
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
|
||||
|
||||
# prevent flaky generation test failures
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
|
||||
torch.manual_seed(0)
|
||||
with torch.no_grad():
|
||||
output_beam_sample = model.beam_sample(
|
||||
@@ -432,6 +445,7 @@ class GenerationTesterMixin:
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask,
|
||||
logits_warper=logits_warper,
|
||||
logits_processor=logits_processor,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -465,6 +479,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
@@ -936,6 +951,7 @@ class GenerationTesterMixin:
|
||||
output_ids_generate = model.generate(
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
@@ -1309,7 +1325,12 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = bart_model.generate(
|
||||
input_ids, num_beams=4, num_return_sequences=2, num_beam_groups=4, diversity_penalty=2.0
|
||||
input_ids,
|
||||
num_beams=4,
|
||||
num_return_sequences=2,
|
||||
num_beam_groups=4,
|
||||
diversity_penalty=2.0,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
@@ -1359,13 +1380,14 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
bart_model.sample(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
with torch.no_grad():
|
||||
bart_model.sample(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_beam_search(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
@@ -1463,14 +1485,15 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
# Sample
|
||||
with self.assertWarns(UserWarning):
|
||||
bart_model.sample(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
stopping_criteria=stopping_criteria,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
with torch.no_grad():
|
||||
bart_model.sample(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
stopping_criteria=stopping_criteria,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# Beam
|
||||
beam_scorer = BeamSearchScorer(
|
||||
@@ -1480,14 +1503,15 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
device=torch_device,
|
||||
)
|
||||
with self.assertWarns(UserWarning):
|
||||
bart_model.beam_search(
|
||||
input_ids,
|
||||
num_beams=num_beams,
|
||||
stopping_criteria=stopping_criteria,
|
||||
max_length=max_length,
|
||||
beam_scorer=beam_scorer,
|
||||
**model_kwargs,
|
||||
)
|
||||
with torch.no_grad():
|
||||
bart_model.beam_search(
|
||||
input_ids,
|
||||
num_beams=num_beams,
|
||||
stopping_criteria=stopping_criteria,
|
||||
max_length=max_length,
|
||||
beam_scorer=beam_scorer,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# Grouped beam search
|
||||
diverse_beam_scorer = BeamSearchScorer(
|
||||
|
||||
Reference in New Issue
Block a user