[Generate] Add save mode logits processor to remove nans and infs if necessary (#10769)

* push

* finish

* finish

* make fix copies

* change name
This commit is contained in:
Patrick von Platen
2021-03-23 01:00:05 +03:00
committed by GitHub
parent 9f8fa4e973
commit 77bf3fe787
9 changed files with 156 additions and 36 deletions

View File

@@ -31,6 +31,7 @@ if is_torch_available():
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
@@ -436,3 +437,24 @@ class LogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores).any())
def test_remove_nan_inf_logits_processor(self):
scores = torch.tensor(
[[0.0, 0.7, 0.8, float("nan")], [0.1, float("inf"), 0.3, float("-inf")]], device=torch_device
)
input_ids = ids_tensor((2, 4), vocab_size=20)
logits_processor = InfNanRemoveLogitsProcessor()
scores = logits_processor(input_ids, scores)
self.assertTrue(
torch.allclose(
scores,
torch.tensor(
[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
device=torch_device,
),
atol=1e-6,
)
)

View File

@@ -29,6 +29,7 @@ if is_torch_available():
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
@@ -229,6 +230,7 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**logits_process_kwargs,
)
@@ -284,6 +286,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**logits_warper_kwargs,
**process_kwargs,
)
@@ -305,19 +308,23 @@ class GenerationTesterMixin:
attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
# prevent flaky generation test failures
logits_processor.append(InfNanRemoveLogitsProcessor())
with torch.no_grad():
output_sample = model.sample(
input_ids_clone,
attention_mask=attention_mask_clone,
max_length=max_length,
logits_processor=logits_processor,
logits_warper=logits_warper,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
)
with torch.no_grad():
output_sample = model.sample(
input_ids_clone,
attention_mask=attention_mask_clone,
max_length=max_length,
logits_processor=logits_processor,
logits_warper=logits_warper,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
)
return output_sample, output_generate
def _beam_search_generate(
@@ -344,6 +351,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**beam_kwargs,
**logits_process_kwargs,
)
@@ -406,6 +414,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**beam_kwargs,
**logits_warper_kwargs,
)
@@ -424,6 +433,10 @@ class GenerationTesterMixin:
else:
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
# prevent flaky generation test failures
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
torch.manual_seed(0)
with torch.no_grad():
output_beam_sample = model.beam_sample(
@@ -432,6 +445,7 @@ class GenerationTesterMixin:
max_length=max_length,
attention_mask=attention_mask,
logits_warper=logits_warper,
logits_processor=logits_processor,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -465,6 +479,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**beam_kwargs,
**logits_process_kwargs,
)
@@ -936,6 +951,7 @@ class GenerationTesterMixin:
output_ids_generate = model.generate(
do_sample=False,
max_length=max_length,
remove_invalid_values=True,
)
self.assertIsNotNone(output_ids_generate)
@@ -1309,7 +1325,12 @@ class GenerationIntegrationTests(unittest.TestCase):
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = bart_model.generate(
input_ids, num_beams=4, num_return_sequences=2, num_beam_groups=4, diversity_penalty=2.0
input_ids,
num_beams=4,
num_return_sequences=2,
num_beam_groups=4,
diversity_penalty=2.0,
remove_invalid_values=True,
)
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
@@ -1359,13 +1380,14 @@ class GenerationIntegrationTests(unittest.TestCase):
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
bart_model.sample(
input_ids,
max_length=max_length,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
with torch.no_grad():
bart_model.sample(
input_ids,
max_length=max_length,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
def test_max_length_backward_compat_beam_search(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -1463,14 +1485,15 @@ class GenerationIntegrationTests(unittest.TestCase):
# Sample
with self.assertWarns(UserWarning):
bart_model.sample(
input_ids,
max_length=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
with torch.no_grad():
bart_model.sample(
input_ids,
max_length=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
# Beam
beam_scorer = BeamSearchScorer(
@@ -1480,14 +1503,15 @@ class GenerationIntegrationTests(unittest.TestCase):
device=torch_device,
)
with self.assertWarns(UserWarning):
bart_model.beam_search(
input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
max_length=max_length,
beam_scorer=beam_scorer,
**model_kwargs,
)
with torch.no_grad():
bart_model.beam_search(
input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
max_length=max_length,
beam_scorer=beam_scorer,
**model_kwargs,
)
# Grouped beam search
diverse_beam_scorer = BeamSearchScorer(