From 9442b3ce316878cf24d59905184f47c315d3f083 Mon Sep 17 00:00:00 2001 From: Kevin Bondzio Date: Fri, 11 Mar 2022 19:36:44 +0100 Subject: [PATCH] Add soft length regulation for sequence generation (#15245) * add possibility to softly regulate length when using sampling method in model.generate() function * fix test config, fix formatting * fix rag integration, fix docstyling * fix wrong docstring * change param to tuple, add test * fix old param in rag_model, remove unused import * change test according to new param * fix formatting * fix test case * fix doc style * move start_length calculation to Logitprocessor * add possibility to softly regulate length when using sampling method in model.generate() function * fix rag integration, fix docstyling * fix test config, fix formatting * change param to tuple, add test * fix old param in rag_model, remove unused import * add possibility to softly regulate length when using sampling method in model.generate() function * change param to tuple, add test * fix old param in rag_model, remove unused import * remove unused import * fix small errors * fix test * add possibility to softly regulate length when using sampling method in model.generate() function * fix test config, fix formatting * fix rag integration, fix docstyling * change param to tuple, add test * fix old param in rag_model, remove unused import * change test according to new param * fix test case * move start_length calculation to Logitprocessor * add possibility to softly regulate length when using sampling method in model.generate() function * fix rag integration, fix docstyling * fix test config, fix formatting * change param to tuple, add test * fix old param in rag_model, remove unused import * add possibility to softly regulate length when using sampling method in model.generate() function * fix test config, fix formatting * fix rag integration, fix docstyling * add possibility to softly regulate length when using sampling method in model.generate() function * fix rag integration, fix docstyling * change param to tuple, add test * fix old param in rag_model, remove unused import * fix small errors * Update src/transformers/generation_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/generation_utils.py * Update src/transformers/generation_utils.py * fix docstring, add type ind model rag * fix docstrings * introduce seq_length variable for cleaner code * fix black formatting * add input_ids_seq_length to modeling_rag * add input_ids_seq_length to test * retrigger checks * retrigger checks Co-authored-by: Kevin Bondzio Co-authored-by: Patrick von Platen Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Kevin Bondzio --- src/transformers/configuration_utils.py | 1 + src/transformers/generation_logits_process.py | 31 ++++++++++++++++- src/transformers/generation_utils.py | 28 ++++++++++++++-- src/transformers/models/rag/modeling_rag.py | 11 ++++++- .../test_generation_logits_process.py | 33 +++++++++++++++++++ tests/test_configuration_common.py | 1 + 6 files changed, 100 insertions(+), 5 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index ef42c35c6e..afc3f8f114 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -295,6 +295,7 @@ class PretrainedConfig(PushToHubMixin): self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None) self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None) self.remove_invalid_values = kwargs.pop("remove_invalid_values", False) + self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None) # Fine-tuning task arguments self.architectures = kwargs.pop("architectures", None) diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index 18f8c5971f..57b62a0354 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -15,7 +15,7 @@ import inspect import math -from typing import Callable, Iterable, List, Optional +from typing import Callable, Iterable, List, Optional, Tuple import numpy as np import torch @@ -647,3 +647,32 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor): scores[scores == float("inf")] = torch.finfo(scores.dtype).max return scores + + +class ExponentialDecayLengthPenalty(LogitsProcessor): + r""" + [`LogitsProcessor`] that exponentially increases the score of the eos_token_id after regulation_start has been + reached. + + Args: + exponential_decay_length_penalty (`tuple(int, float)`, *optional*): + This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty + starts and `decay_factor` represents the factor of exponential decay + eos_token_id (`int`): + The id of the *end-of-sequence* token. + input_ids_seq_length (`int`): + The length of the input sequence. + """ + + def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int): + self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length + self.regulation_factor = exponential_decay_length_penalty[1] + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor: + cur_len = input_ids.shape[-1] + if cur_len > self.regulation_start: + scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow( + self.regulation_factor, cur_len - self.regulation_start + ) + return scores diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 85bbc51e6f..62f37ad624 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -28,6 +28,7 @@ from .generation_beam_constraints import Constraint, DisjunctiveConstraint, Phra from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .generation_logits_process import ( EncoderNoRepeatNGramLogitsProcessor, + ExponentialDecayLengthPenalty, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, HammingDiversityLogitsProcessor, @@ -667,6 +668,7 @@ class GenerationMixin: repetition_penalty: float, no_repeat_ngram_size: int, encoder_no_repeat_ngram_size: int, + input_ids_seq_length: int, encoder_input_ids: torch.LongTensor, bad_words_ids: List[List[int]], min_length: int, @@ -679,6 +681,7 @@ class GenerationMixin: num_beam_groups: int, diversity_penalty: float, remove_invalid_values: bool, + exponential_decay_length_penalty: Tuple, logits_processor: Optional[LogitsProcessorList], ) -> LogitsProcessorList: """ @@ -710,6 +713,11 @@ class GenerationMixin: remove_invalid_values = ( remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values ) + exponential_decay_length_penalty = ( + exponential_decay_length_penalty + if exponential_decay_length_penalty is not None + else self.config.exponential_decay_length_penalty + ) # instantiate processors list # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files @@ -743,6 +751,10 @@ class GenerationMixin: processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) if remove_invalid_values is True: processors.append(InfNanRemoveLogitsProcessor()) + if exponential_decay_length_penalty is not None: + processors.append( + ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length) + ) processors = self._merge_criteria_processor_list(processors, logits_processor) return processors @@ -858,6 +870,7 @@ class GenerationMixin: forced_eos_token_id: Optional[int] = None, remove_invalid_values: Optional[bool] = None, synced_gpus: Optional[bool] = False, + exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, **model_kwargs, ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: r""" @@ -1003,6 +1016,11 @@ class GenerationMixin: crash. Note that using `remove_invalid_values` can slow down generation. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + exponential_decay_length_penalty (`tuple(int, float)`, *optional*): + This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been + generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates + where penalty starts and `decay_factor` represents the factor of exponential decay + model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs @@ -1152,10 +1170,12 @@ class GenerationMixin: # if decoder-only then inputs_tensor has to be `input_ids` input_ids = inputs_tensor + input_ids_seq_length = input_ids.shape[-1] + # 5. Prepare `max_length` depending on other stopping criteria # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens` if max_length is None and max_new_tokens is not None: - max_length = max_new_tokens + input_ids.shape[-1] + max_length = max_new_tokens + input_ids_seq_length elif max_length is not None and max_new_tokens is not None: # Both are set, this is odd, raise a warning warnings.warn( @@ -1167,10 +1187,10 @@ class GenerationMixin: # default to config if still None max_length = max_length if max_length is not None else self.config.max_length - if input_ids.shape[-1] >= max_length: + if input_ids_seq_length >= max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" logger.warning( - f"Input length of {input_ids_string} is {input_ids.shape[-1]}, but ``max_length`` is set to {max_length}. " + f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to {max_length}. " "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``." ) @@ -1202,6 +1222,7 @@ class GenerationMixin: repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, + input_ids_seq_length=input_ids_seq_length, encoder_input_ids=inputs_tensor, bad_words_ids=bad_words_ids, min_length=min_length, @@ -1214,6 +1235,7 @@ class GenerationMixin: num_beam_groups=num_beam_groups, diversity_penalty=diversity_penalty, remove_invalid_values=remove_invalid_values, + exponential_decay_length_penalty=exponential_decay_length_penalty, logits_processor=logits_processor, ) diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index dc2de04b01..480bce4037 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -15,7 +15,7 @@ """RAG model implementation.""" from dataclasses import dataclass -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -1405,6 +1405,7 @@ class RagTokenForGeneration(RagPreTrainedModel): forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, remove_invalid_values: Optional[bool] = None, + exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, **model_kwargs ): """ @@ -1534,6 +1535,11 @@ class RagTokenForGeneration(RagPreTrainedModel): remove_invalid_values = ( remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values ) + exponential_decay_length_penalty = ( + exponential_decay_length_penalty + if exponential_decay_length_penalty is not None + else self.config.exponential_decay_length_penalty + ) # retrieve docs if self.retriever is not None and context_input_ids is None: @@ -1577,6 +1583,7 @@ class RagTokenForGeneration(RagPreTrainedModel): dtype=torch.long, device=next(self.parameters()).device, ) + input_ids_seq_length = input_ids.shape[-1] last_hidden_state = encoder_outputs["last_hidden_state"] def extend_enc_output(tensor, num_beams=None): @@ -1603,6 +1610,7 @@ class RagTokenForGeneration(RagPreTrainedModel): repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, + input_ids_seq_length=input_ids_seq_length, encoder_input_ids=context_input_ids, bad_words_ids=bad_words_ids, min_length=min_length, @@ -1615,6 +1623,7 @@ class RagTokenForGeneration(RagPreTrainedModel): num_beam_groups=num_beam_groups, diversity_penalty=diversity_penalty, remove_invalid_values=remove_invalid_values, + exponential_decay_length_penalty=exponential_decay_length_penalty, logits_processor=logits_processor, ) diff --git a/tests/generation/test_generation_logits_process.py b/tests/generation/test_generation_logits_process.py index 5ffc6843a1..b95110d0e0 100644 --- a/tests/generation/test_generation_logits_process.py +++ b/tests/generation/test_generation_logits_process.py @@ -28,6 +28,7 @@ if is_torch_available(): from transformers.generation_logits_process import ( EncoderNoRepeatNGramLogitsProcessor, + ExponentialDecayLengthPenalty, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, HammingDiversityLogitsProcessor, @@ -504,3 +505,35 @@ class LogitsProcessorTest(unittest.TestCase): atol=1e-6, ) ) + + def test_exponential_decay_length_penalty(self): + vocab_size = 20 + batch_size = 4 + eos_token_id = 0 + + penalty_start = 5 + penalty_factor = 1.1 + + input_ids = ids_tensor((batch_size, 2), vocab_size=vocab_size) + input_ids_seq_length = input_ids.shape[-1] + + length_decay_processor = ExponentialDecayLengthPenalty( + exponential_decay_length_penalty=(penalty_start, penalty_factor), + eos_token_id=eos_token_id, + input_ids_seq_length=input_ids_seq_length, + ) + + # check that penalty is not applied before start + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_start = length_decay_processor(input_ids, scores) + self.assertListEqual(scores_before_start[:, eos_token_id].tolist(), scores[:, eos_token_id].tolist()) + + # check that penalty is applied after start + input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_after_start = length_decay_processor(input_ids, scores) + self.assertTrue( + torch.gt( + scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id] + ).all() + ) diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index a073c52507..08523de9e3 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -82,6 +82,7 @@ config_common_kwargs = { "eos_token_id": 8, "sep_token_id": 9, "decoder_start_token_id": 10, + "exponential_decay_length_penalty": (5, 1.01), "task_specific_params": {"translation": "some_params"}, "problem_type": "regression", }