Add soft length regulation for sequence generation (#15245)
* add possibility to softly regulate length when using sampling method in model.generate() function * fix test config, fix formatting * fix rag integration, fix docstyling * fix wrong docstring * change param to tuple, add test * fix old param in rag_model, remove unused import * change test according to new param * fix formatting * fix test case * fix doc style * move start_length calculation to Logitprocessor * add possibility to softly regulate length when using sampling method in model.generate() function * fix rag integration, fix docstyling * fix test config, fix formatting * change param to tuple, add test * fix old param in rag_model, remove unused import * add possibility to softly regulate length when using sampling method in model.generate() function * change param to tuple, add test * fix old param in rag_model, remove unused import * remove unused import * fix small errors * fix test * add possibility to softly regulate length when using sampling method in model.generate() function * fix test config, fix formatting * fix rag integration, fix docstyling * change param to tuple, add test * fix old param in rag_model, remove unused import * change test according to new param * fix test case * move start_length calculation to Logitprocessor * add possibility to softly regulate length when using sampling method in model.generate() function * fix rag integration, fix docstyling * fix test config, fix formatting * change param to tuple, add test * fix old param in rag_model, remove unused import * add possibility to softly regulate length when using sampling method in model.generate() function * fix test config, fix formatting * fix rag integration, fix docstyling * add possibility to softly regulate length when using sampling method in model.generate() function * fix rag integration, fix docstyling * change param to tuple, add test * fix old param in rag_model, remove unused import * fix small errors * Update src/transformers/generation_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/generation_utils.py * Update src/transformers/generation_utils.py * fix docstring, add type ind model rag * fix docstrings * introduce seq_length variable for cleaner code * fix black formatting * add input_ids_seq_length to modeling_rag * add input_ids_seq_length to test * retrigger checks * retrigger checks Co-authored-by: Kevin Bondzio <kev@AIM-LAP-02.local> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Kevin Bondzio <kev@AIM-LAP-02.fritz.box>
This commit is contained in:
@@ -295,6 +295,7 @@ class PretrainedConfig(PushToHubMixin):
|
||||
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
|
||||
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
|
||||
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
|
||||
self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
|
||||
|
||||
# Fine-tuning task arguments
|
||||
self.architectures = kwargs.pop("architectures", None)
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Callable, Iterable, List, Optional
|
||||
from typing import Callable, Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -647,3 +647,32 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor):
|
||||
scores[scores == float("inf")] = torch.finfo(scores.dtype).max
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
class ExponentialDecayLengthPenalty(LogitsProcessor):
|
||||
r"""
|
||||
[`LogitsProcessor`] that exponentially increases the score of the eos_token_id after regulation_start has been
|
||||
reached.
|
||||
|
||||
Args:
|
||||
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
|
||||
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
|
||||
starts and `decay_factor` represents the factor of exponential decay
|
||||
eos_token_id (`int`):
|
||||
The id of the *end-of-sequence* token.
|
||||
input_ids_seq_length (`int`):
|
||||
The length of the input sequence.
|
||||
"""
|
||||
|
||||
def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int):
|
||||
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
|
||||
self.regulation_factor = exponential_decay_length_penalty[1]
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
|
||||
cur_len = input_ids.shape[-1]
|
||||
if cur_len > self.regulation_start:
|
||||
scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow(
|
||||
self.regulation_factor, cur_len - self.regulation_start
|
||||
)
|
||||
return scores
|
||||
|
||||
@@ -28,6 +28,7 @@ from .generation_beam_constraints import Constraint, DisjunctiveConstraint, Phra
|
||||
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .generation_logits_process import (
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
ExponentialDecayLengthPenalty,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
@@ -667,6 +668,7 @@ class GenerationMixin:
|
||||
repetition_penalty: float,
|
||||
no_repeat_ngram_size: int,
|
||||
encoder_no_repeat_ngram_size: int,
|
||||
input_ids_seq_length: int,
|
||||
encoder_input_ids: torch.LongTensor,
|
||||
bad_words_ids: List[List[int]],
|
||||
min_length: int,
|
||||
@@ -679,6 +681,7 @@ class GenerationMixin:
|
||||
num_beam_groups: int,
|
||||
diversity_penalty: float,
|
||||
remove_invalid_values: bool,
|
||||
exponential_decay_length_penalty: Tuple,
|
||||
logits_processor: Optional[LogitsProcessorList],
|
||||
) -> LogitsProcessorList:
|
||||
"""
|
||||
@@ -710,6 +713,11 @@ class GenerationMixin:
|
||||
remove_invalid_values = (
|
||||
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
|
||||
)
|
||||
exponential_decay_length_penalty = (
|
||||
exponential_decay_length_penalty
|
||||
if exponential_decay_length_penalty is not None
|
||||
else self.config.exponential_decay_length_penalty
|
||||
)
|
||||
# instantiate processors list
|
||||
|
||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||
@@ -743,6 +751,10 @@ class GenerationMixin:
|
||||
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
|
||||
if remove_invalid_values is True:
|
||||
processors.append(InfNanRemoveLogitsProcessor())
|
||||
if exponential_decay_length_penalty is not None:
|
||||
processors.append(
|
||||
ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
|
||||
)
|
||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||
return processors
|
||||
|
||||
@@ -858,6 +870,7 @@ class GenerationMixin:
|
||||
forced_eos_token_id: Optional[int] = None,
|
||||
remove_invalid_values: Optional[bool] = None,
|
||||
synced_gpus: Optional[bool] = False,
|
||||
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
|
||||
**model_kwargs,
|
||||
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
|
||||
r"""
|
||||
@@ -1003,6 +1016,11 @@ class GenerationMixin:
|
||||
crash. Note that using `remove_invalid_values` can slow down generation.
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
|
||||
This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been
|
||||
generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates
|
||||
where penalty starts and `decay_factor` represents the factor of exponential decay
|
||||
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
|
||||
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
|
||||
@@ -1152,10 +1170,12 @@ class GenerationMixin:
|
||||
# if decoder-only then inputs_tensor has to be `input_ids`
|
||||
input_ids = inputs_tensor
|
||||
|
||||
input_ids_seq_length = input_ids.shape[-1]
|
||||
|
||||
# 5. Prepare `max_length` depending on other stopping criteria
|
||||
# if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens`
|
||||
if max_length is None and max_new_tokens is not None:
|
||||
max_length = max_new_tokens + input_ids.shape[-1]
|
||||
max_length = max_new_tokens + input_ids_seq_length
|
||||
elif max_length is not None and max_new_tokens is not None:
|
||||
# Both are set, this is odd, raise a warning
|
||||
warnings.warn(
|
||||
@@ -1167,10 +1187,10 @@ class GenerationMixin:
|
||||
# default to config if still None
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
|
||||
if input_ids.shape[-1] >= max_length:
|
||||
if input_ids_seq_length >= max_length:
|
||||
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
||||
logger.warning(
|
||||
f"Input length of {input_ids_string} is {input_ids.shape[-1]}, but ``max_length`` is set to {max_length}. "
|
||||
f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to {max_length}. "
|
||||
"This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
|
||||
)
|
||||
|
||||
@@ -1202,6 +1222,7 @@ class GenerationMixin:
|
||||
repetition_penalty=repetition_penalty,
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
encoder_input_ids=inputs_tensor,
|
||||
bad_words_ids=bad_words_ids,
|
||||
min_length=min_length,
|
||||
@@ -1214,6 +1235,7 @@ class GenerationMixin:
|
||||
num_beam_groups=num_beam_groups,
|
||||
diversity_penalty=diversity_penalty,
|
||||
remove_invalid_values=remove_invalid_values,
|
||||
exponential_decay_length_penalty=exponential_decay_length_penalty,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"""RAG model implementation."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -1405,6 +1405,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
forced_bos_token_id: Optional[int] = None,
|
||||
forced_eos_token_id: Optional[int] = None,
|
||||
remove_invalid_values: Optional[bool] = None,
|
||||
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
|
||||
**model_kwargs
|
||||
):
|
||||
"""
|
||||
@@ -1534,6 +1535,11 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
remove_invalid_values = (
|
||||
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
|
||||
)
|
||||
exponential_decay_length_penalty = (
|
||||
exponential_decay_length_penalty
|
||||
if exponential_decay_length_penalty is not None
|
||||
else self.config.exponential_decay_length_penalty
|
||||
)
|
||||
|
||||
# retrieve docs
|
||||
if self.retriever is not None and context_input_ids is None:
|
||||
@@ -1577,6 +1583,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
input_ids_seq_length = input_ids.shape[-1]
|
||||
last_hidden_state = encoder_outputs["last_hidden_state"]
|
||||
|
||||
def extend_enc_output(tensor, num_beams=None):
|
||||
@@ -1603,6 +1610,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
repetition_penalty=repetition_penalty,
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
encoder_input_ids=context_input_ids,
|
||||
bad_words_ids=bad_words_ids,
|
||||
min_length=min_length,
|
||||
@@ -1615,6 +1623,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
num_beam_groups=num_beam_groups,
|
||||
diversity_penalty=diversity_penalty,
|
||||
remove_invalid_values=remove_invalid_values,
|
||||
exponential_decay_length_penalty=exponential_decay_length_penalty,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ if is_torch_available():
|
||||
|
||||
from transformers.generation_logits_process import (
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
ExponentialDecayLengthPenalty,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
@@ -504,3 +505,35 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
atol=1e-6,
|
||||
)
|
||||
)
|
||||
|
||||
def test_exponential_decay_length_penalty(self):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
eos_token_id = 0
|
||||
|
||||
penalty_start = 5
|
||||
penalty_factor = 1.1
|
||||
|
||||
input_ids = ids_tensor((batch_size, 2), vocab_size=vocab_size)
|
||||
input_ids_seq_length = input_ids.shape[-1]
|
||||
|
||||
length_decay_processor = ExponentialDecayLengthPenalty(
|
||||
exponential_decay_length_penalty=(penalty_start, penalty_factor),
|
||||
eos_token_id=eos_token_id,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
)
|
||||
|
||||
# check that penalty is not applied before start
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_before_start = length_decay_processor(input_ids, scores)
|
||||
self.assertListEqual(scores_before_start[:, eos_token_id].tolist(), scores[:, eos_token_id].tolist())
|
||||
|
||||
# check that penalty is applied after start
|
||||
input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_after_start = length_decay_processor(input_ids, scores)
|
||||
self.assertTrue(
|
||||
torch.gt(
|
||||
scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
|
||||
).all()
|
||||
)
|
||||
|
||||
@@ -82,6 +82,7 @@ config_common_kwargs = {
|
||||
"eos_token_id": 8,
|
||||
"sep_token_id": 9,
|
||||
"decoder_start_token_id": 10,
|
||||
"exponential_decay_length_penalty": (5, 1.01),
|
||||
"task_specific_params": {"translation": "some_params"},
|
||||
"problem_type": "regression",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user