From 77bf3fe787b454aceac4a2ef88b147180f8828fb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 23 Mar 2021 01:00:05 +0300 Subject: [PATCH] [Generate] Add save mode logits processor to remove nans and infs if necessary (#10769) * push * finish * finish * make fix copies * change name --- docs/source/internal/generation_utils.rst | 10 ++ src/transformers/__init__.py | 6 ++ src/transformers/configuration_utils.py | 4 + src/transformers/generation_logits_process.py | 17 ++++ src/transformers/generation_utils.py | 14 +++ src/transformers/models/rag/modeling_rag.py | 8 ++ src/transformers/utils/dummy_pt_objects.py | 15 +++ tests/test_generation_logits_process.py | 22 +++++ tests/test_generation_utils.py | 96 ++++++++++++------- 9 files changed, 156 insertions(+), 36 deletions(-) diff --git a/docs/source/internal/generation_utils.rst b/docs/source/internal/generation_utils.rst index 25fc82cbbe..9051a44721 100644 --- a/docs/source/internal/generation_utils.rst +++ b/docs/source/internal/generation_utils.rst @@ -151,6 +151,16 @@ generation. .. autoclass:: transformers.HammingDiversityLogitsProcessor :members: __call__ +.. autoclass:: transformers.ForcedBOSTokenLogitsProcessor + :members: __call__ + +.. autoclass:: transformers.ForcedEOSTokenLogitsProcessor + :members: __call__ + +.. autoclass:: transformers.InfNanRemoveLogitsProcessor + :members: __call__ + + StoppingCriteria ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5d8aa3e427..fe5ff901aa 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -369,7 +369,10 @@ if is_torch_available(): ] _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer"] _import_structure["generation_logits_process"] = [ + "ForcedBOSTokenLogitsProcessor", + "ForcedEOSTokenLogitsProcessor", "HammingDiversityLogitsProcessor", + "InfNanRemoveLogitsProcessor", "LogitsProcessor", "LogitsProcessorList", "LogitsWarper", @@ -1560,7 +1563,10 @@ if TYPE_CHECKING: ) from .generation_beam_search import BeamScorer, BeamSearchScorer from .generation_logits_process import ( + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, HammingDiversityLogitsProcessor, + InfNanRemoveLogitsProcessor, LogitsProcessor, LogitsProcessorList, LogitsWarper, diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index c6830f5083..1c428eae5c 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -134,6 +134,9 @@ class PretrainedConfig(object): <../model_doc/mbart>` where the first generated token needs to be the target language token. - **forced_eos_token_id** (:obj:`int`, `optional`) -- The id of the token to force as the last generated token when :obj:`max_length` is reached. + - **remove_invalid_values** (:obj:`bool`, `optional`) -- Whether to remove possible `nan` and `inf` outputs of + the model to prevent the generation method to crash. Note that using ``remove_invalid_values`` can slow down + generation. Parameters for fine-tuning tasks @@ -219,6 +222,7 @@ class PretrainedConfig(object): self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False) self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None) self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None) + self.remove_invalid_values = kwargs.pop("remove_invalid_values", False) # Fine-tuning task arguments self.architectures = kwargs.pop("architectures", None) diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index 8d42aba12a..a2fa58d6f7 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -566,3 +566,20 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf") scores[:, self.eos_token_id] = 0 return scores + + +class InfNanRemoveLogitsProcessor(LogitsProcessor): + r""" + :class:`~transformers.LogitsProcessor` that removes all :obj:`nan` and :obj:`inf` values to avoid the generation + method to fail. Note that using the logits processor should only be used if necessary since it can slow down the + generation method. :obj:`max_length` is reached. + """ + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # set all nan values to 0.0 + scores[scores != scores] = 0.0 + + # set all inf values to max possible value + scores[scores == float("inf")] = torch.finfo(scores.dtype).max + + return scores diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 85f0afe5c6..e5aea93944 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -27,6 +27,7 @@ from .generation_logits_process import ( ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, HammingDiversityLogitsProcessor, + InfNanRemoveLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, @@ -581,6 +582,7 @@ class GenerationMixin: num_beams: int, num_beam_groups: int, diversity_penalty: float, + remove_invalid_values: bool, ) -> LogitsProcessorList: """ This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant @@ -607,6 +609,9 @@ class GenerationMixin: forced_eos_token_id = ( forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id ) + remove_invalid_values = ( + remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values + ) # instantiate processors list processors = LogitsProcessorList() @@ -639,6 +644,8 @@ class GenerationMixin: processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) if forced_eos_token_id is not None: processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) + if remove_invalid_values is True: + processors.append(InfNanRemoveLogitsProcessor()) return processors def _get_stopping_criteria( @@ -687,6 +694,7 @@ class GenerationMixin: return_dict_in_generate: Optional[bool] = None, forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, + remove_invalid_values: Optional[bool] = None, **model_kwargs, ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: r""" @@ -789,6 +797,9 @@ class GenerationMixin: needs to be the target language token. forced_eos_token_id (:obj:`int`, `optional`): The id of the token to force as the last generated token when :obj:`max_length` is reached. + remove_invalid_values (:obj:`bool`, `optional`): + Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method to + crash. Note that using ``remove_invalid_values`` can slow down generation. model_kwargs: Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the @@ -965,6 +976,7 @@ class GenerationMixin: num_beams=num_beams, num_beam_groups=num_beam_groups, diversity_penalty=diversity_penalty, + remove_invalid_values=remove_invalid_values, ) stopping_criteria = self._get_stopping_criteria( @@ -1511,6 +1523,7 @@ class GenerationMixin: # sample probs = F.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # add code that transfomers next_tokens to tokens_to_add @@ -2026,6 +2039,7 @@ class GenerationMixin: next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) probs = F.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) next_token_scores = torch.gather(next_token_scores, -1, next_tokens) diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 5e9e8c356a..ae735926b2 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1316,6 +1316,7 @@ class RagTokenForGeneration(RagPreTrainedModel): prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None, forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, + remove_invalid_values: Optional[bool] = None, **model_kwargs ): """ @@ -1412,6 +1413,9 @@ class RagTokenForGeneration(RagPreTrainedModel): needs to be the target language token. forced_eos_token_id (:obj:`int`, `optional`): The id of the token to force as the last generated token when :obj:`max_length` is reached. + remove_invalid_values (:obj:`bool`, `optional`): + Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method to + crash. Note that using ``remove_invalid_values`` can slow down generation. Return: :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated @@ -1435,6 +1439,9 @@ class RagTokenForGeneration(RagPreTrainedModel): if decoder_start_token_id is not None else self.config.generator.decoder_start_token_id ) + remove_invalid_values = ( + remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values + ) # retrieve docs if self.retriever is not None and context_input_ids is None: @@ -1515,6 +1522,7 @@ class RagTokenForGeneration(RagPreTrainedModel): num_beams=num_beams, num_beam_groups=num_beam_groups, diversity_penalty=diversity_penalty, + remove_invalid_values=remove_invalid_values, ) if num_beams == 1: diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index d5ddcd2e3c..00a84b6810 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -123,11 +123,26 @@ class BeamSearchScorer: requires_pytorch(self) +class ForcedBOSTokenLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class ForcedEOSTokenLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class HammingDiversityLogitsProcessor: def __init__(self, *args, **kwargs): requires_pytorch(self) +class InfNanRemoveLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class LogitsProcessor: def __init__(self, *args, **kwargs): requires_pytorch(self) diff --git a/tests/test_generation_logits_process.py b/tests/test_generation_logits_process.py index 85a589b7c2..2e00be0fa4 100644 --- a/tests/test_generation_logits_process.py +++ b/tests/test_generation_logits_process.py @@ -31,6 +31,7 @@ if is_torch_available(): ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, HammingDiversityLogitsProcessor, + InfNanRemoveLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, @@ -436,3 +437,24 @@ class LogitsProcessorTest(unittest.TestCase): scores = self._get_uniform_logits(batch_size, vocab_size) scores = logits_processor(input_ids, scores) self.assertFalse(torch.isinf(scores).any()) + + def test_remove_nan_inf_logits_processor(self): + scores = torch.tensor( + [[0.0, 0.7, 0.8, float("nan")], [0.1, float("inf"), 0.3, float("-inf")]], device=torch_device + ) + input_ids = ids_tensor((2, 4), vocab_size=20) + + logits_processor = InfNanRemoveLogitsProcessor() + + scores = logits_processor(input_ids, scores) + + self.assertTrue( + torch.allclose( + scores, + torch.tensor( + [[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]], + device=torch_device, + ), + atol=1e-6, + ) + ) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index 6dc72fbc47..6b84a42e07 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -29,6 +29,7 @@ if is_torch_available(): ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, HammingDiversityLogitsProcessor, + InfNanRemoveLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, @@ -229,6 +230,7 @@ class GenerationTesterMixin: output_hidden_states=output_hidden_states, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, + remove_invalid_values=True, **logits_process_kwargs, ) @@ -284,6 +286,7 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + remove_invalid_values=True, **logits_warper_kwargs, **process_kwargs, ) @@ -305,19 +308,23 @@ class GenerationTesterMixin: attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0) input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0) + # prevent flaky generation test failures + logits_processor.append(InfNanRemoveLogitsProcessor()) + with torch.no_grad(): - output_sample = model.sample( - input_ids_clone, - attention_mask=attention_mask_clone, - max_length=max_length, - logits_processor=logits_processor, - logits_warper=logits_warper, - output_scores=output_scores, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - ) + with torch.no_grad(): + output_sample = model.sample( + input_ids_clone, + attention_mask=attention_mask_clone, + max_length=max_length, + logits_processor=logits_processor, + logits_warper=logits_warper, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + **kwargs, + ) return output_sample, output_generate def _beam_search_generate( @@ -344,6 +351,7 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + remove_invalid_values=True, **beam_kwargs, **logits_process_kwargs, ) @@ -406,6 +414,7 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + remove_invalid_values=True, **beam_kwargs, **logits_warper_kwargs, ) @@ -424,6 +433,10 @@ class GenerationTesterMixin: else: attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0) + # prevent flaky generation test failures + logits_processor = LogitsProcessorList() + logits_processor.append(InfNanRemoveLogitsProcessor()) + torch.manual_seed(0) with torch.no_grad(): output_beam_sample = model.beam_sample( @@ -432,6 +445,7 @@ class GenerationTesterMixin: max_length=max_length, attention_mask=attention_mask, logits_warper=logits_warper, + logits_processor=logits_processor, output_scores=output_scores, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -465,6 +479,7 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + remove_invalid_values=True, **beam_kwargs, **logits_process_kwargs, ) @@ -936,6 +951,7 @@ class GenerationTesterMixin: output_ids_generate = model.generate( do_sample=False, max_length=max_length, + remove_invalid_values=True, ) self.assertIsNotNone(output_ids_generate) @@ -1309,7 +1325,12 @@ class GenerationIntegrationTests(unittest.TestCase): input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) outputs = bart_model.generate( - input_ids, num_beams=4, num_return_sequences=2, num_beam_groups=4, diversity_penalty=2.0 + input_ids, + num_beams=4, + num_return_sequences=2, + num_beam_groups=4, + diversity_penalty=2.0, + remove_invalid_values=True, ) generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True) @@ -1359,13 +1380,14 @@ class GenerationIntegrationTests(unittest.TestCase): decoder_start_token_id=bart_model.config.decoder_start_token_id, bos_token_id=bart_model.config.bos_token_id, ) - bart_model.sample( - input_ids, - max_length=max_length, - pad_token_id=bart_model.config.pad_token_id, - eos_token_id=bart_model.config.eos_token_id, - **model_kwargs, - ) + with torch.no_grad(): + bart_model.sample( + input_ids, + max_length=max_length, + pad_token_id=bart_model.config.pad_token_id, + eos_token_id=bart_model.config.eos_token_id, + **model_kwargs, + ) def test_max_length_backward_compat_beam_search(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" @@ -1463,14 +1485,15 @@ class GenerationIntegrationTests(unittest.TestCase): # Sample with self.assertWarns(UserWarning): - bart_model.sample( - input_ids, - max_length=max_length, - stopping_criteria=stopping_criteria, - pad_token_id=bart_model.config.pad_token_id, - eos_token_id=bart_model.config.eos_token_id, - **model_kwargs, - ) + with torch.no_grad(): + bart_model.sample( + input_ids, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=bart_model.config.pad_token_id, + eos_token_id=bart_model.config.eos_token_id, + **model_kwargs, + ) # Beam beam_scorer = BeamSearchScorer( @@ -1480,14 +1503,15 @@ class GenerationIntegrationTests(unittest.TestCase): device=torch_device, ) with self.assertWarns(UserWarning): - bart_model.beam_search( - input_ids, - num_beams=num_beams, - stopping_criteria=stopping_criteria, - max_length=max_length, - beam_scorer=beam_scorer, - **model_kwargs, - ) + with torch.no_grad(): + bart_model.beam_search( + input_ids, + num_beams=num_beams, + stopping_criteria=stopping_criteria, + max_length=max_length, + beam_scorer=beam_scorer, + **model_kwargs, + ) # Grouped beam search diverse_beam_scorer = BeamSearchScorer(