From f7196f2e63b14e9fbb4ad664e71912aab3b484cf Mon Sep 17 00:00:00 2001 From: Santiago Castro Date: Wed, 13 Apr 2022 04:37:33 -0400 Subject: [PATCH] Fix decoding score comparison when using logits processors or warpers (#10638) * Normalize using a logits warper * Add a flag in `generate` to support the logit renormalization * Add in RAG --- src/transformers/generation_logits_process.py | 13 +++++++++ src/transformers/generation_utils.py | 29 +++++++++++++++++-- src/transformers/models/rag/modeling_rag.py | 2 ++ .../test_generation_logits_process.py | 16 ++++++++++ 4 files changed, 58 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index 1b0762079d..7aa4004913 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -679,3 +679,16 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): self.regulation_factor, cur_len - self.regulation_start ) return scores + + +class LogitNormalization(LogitsProcessor, LogitsWarper): + r""" + [`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize + the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in + this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that + the scores are normalized when comparing the hypotheses. + """ + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + scores = scores.log_softmax(dim=-1) + return scores diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index b37e11af03..1bdcd06f0d 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -32,6 +32,7 @@ from .generation_logits_process import ( ForcedEOSTokenLogitsProcessor, HammingDiversityLogitsProcessor, InfNanRemoveLogitsProcessor, + LogitNormalization, LogitsProcessorList, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, @@ -636,6 +637,7 @@ class GenerationMixin: typical_p: Optional[float] = None, temperature: Optional[float] = None, num_beams: Optional[int] = None, + renormalize_logits: Optional[bool] = None, ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances @@ -660,6 +662,9 @@ class GenerationMixin: warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) if typical_p is not None and typical_p < 1.0: warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + # `LogitNormalization` should always be the last logit processor, when present + if renormalize_logits is True: + warpers.append(LogitNormalization()) return warpers def _get_logits_processor( @@ -682,6 +687,7 @@ class GenerationMixin: remove_invalid_values: bool, exponential_decay_length_penalty: Tuple, logits_processor: Optional[LogitsProcessorList], + renormalize_logits: Optional[bool], ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] @@ -754,6 +760,9 @@ class GenerationMixin: ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length) ) processors = self._merge_criteria_processor_list(processors, logits_processor) + # `LogitNormalization` should always be the last logit processor, when present + if renormalize_logits is True: + processors.append(LogitNormalization()) return processors def _get_stopping_criteria( @@ -858,6 +867,7 @@ class GenerationMixin: diversity_penalty: Optional[float] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), + renormalize_logits: Optional[bool] = None, stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), constraints: Optional[List[Constraint]] = None, output_attentions: Optional[bool] = None, @@ -986,6 +996,10 @@ class GenerationMixin: Custom logits processors that complement the default logits processors built from arguments and a model's config. If a logit processor is passed that is already created with the arguments or a model's config an error is thrown. This feature is intended for advanced users. + renormalize_logits: (`bool`, *optional*, defaults to `False`): + Whether to renormalize the logits after applying all the logits processors or warpers (including the + custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the + score logits are normalized but some logit processors or warpers break the normalization. stopping_criteria (`StoppingCriteriaList`, *optional*): Custom stopping criteria that complement the default stopping criteria built from arguments and a model's config. If a stopping criteria is passed that is already created with the arguments or a @@ -1241,6 +1255,7 @@ class GenerationMixin: remove_invalid_values=remove_invalid_values, exponential_decay_length_penalty=exponential_decay_length_penalty, logits_processor=logits_processor, + renormalize_logits=renormalize_logits, ) # 8. prepare stopping criteria @@ -1271,7 +1286,12 @@ class GenerationMixin: elif is_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper( - top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams + top_k=top_k, + top_p=top_p, + typical_p=typical_p, + temperature=temperature, + num_beams=num_beams, + renormalize_logits=renormalize_logits, ) # 11. expand input_ids with `num_return_sequences` additional sequences per batch @@ -1333,7 +1353,12 @@ class GenerationMixin: elif is_beam_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper( - top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams + top_k=top_k, + top_p=top_p, + typical_p=typical_p, + temperature=temperature, + num_beams=num_beams, + renormalize_logits=renormalize_logits, ) if stopping_criteria.max_length is None: diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 593efa694e..642b13c580 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1400,6 +1400,7 @@ class RagTokenForGeneration(RagPreTrainedModel): n_docs: Optional[int] = None, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None, logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), + renormalize_logits: Optional[bool] = None, stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, @@ -1624,6 +1625,7 @@ class RagTokenForGeneration(RagPreTrainedModel): remove_invalid_values=remove_invalid_values, exponential_decay_length_penalty=exponential_decay_length_penalty, logits_processor=logits_processor, + renormalize_logits=renormalize_logits, ) if num_beams == 1: diff --git a/tests/generation/test_generation_logits_process.py b/tests/generation/test_generation_logits_process.py index 8489efd44d..7a515d3e92 100644 --- a/tests/generation/test_generation_logits_process.py +++ b/tests/generation/test_generation_logits_process.py @@ -33,6 +33,7 @@ if is_torch_available(): ForcedEOSTokenLogitsProcessor, HammingDiversityLogitsProcessor, InfNanRemoveLogitsProcessor, + LogitNormalization, LogitsProcessorList, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, @@ -537,3 +538,18 @@ class LogitsProcessorTest(unittest.TestCase): scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id] ).all() ) + + def test_normalization(self): + input_ids = None + + scores = torch.tensor( + [[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]], device=torch_device, dtype=torch.float + ) + + logit_normalization = LogitNormalization() + normalized_scores = logit_normalization(input_ids, scores).exp() + + ones = torch.ones(scores.shape[0], device=torch_device, dtype=torch.float) + self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones)) + + self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))