Add soft length regulation for sequence generation (#15245)
* add possibility to softly regulate length when using sampling method in model.generate() function * fix test config, fix formatting * fix rag integration, fix docstyling * fix wrong docstring * change param to tuple, add test * fix old param in rag_model, remove unused import * change test according to new param * fix formatting * fix test case * fix doc style * move start_length calculation to Logitprocessor * add possibility to softly regulate length when using sampling method in model.generate() function * fix rag integration, fix docstyling * fix test config, fix formatting * change param to tuple, add test * fix old param in rag_model, remove unused import * add possibility to softly regulate length when using sampling method in model.generate() function * change param to tuple, add test * fix old param in rag_model, remove unused import * remove unused import * fix small errors * fix test * add possibility to softly regulate length when using sampling method in model.generate() function * fix test config, fix formatting * fix rag integration, fix docstyling * change param to tuple, add test * fix old param in rag_model, remove unused import * change test according to new param * fix test case * move start_length calculation to Logitprocessor * add possibility to softly regulate length when using sampling method in model.generate() function * fix rag integration, fix docstyling * fix test config, fix formatting * change param to tuple, add test * fix old param in rag_model, remove unused import * add possibility to softly regulate length when using sampling method in model.generate() function * fix test config, fix formatting * fix rag integration, fix docstyling * add possibility to softly regulate length when using sampling method in model.generate() function * fix rag integration, fix docstyling * change param to tuple, add test * fix old param in rag_model, remove unused import * fix small errors * Update src/transformers/generation_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/generation_utils.py * Update src/transformers/generation_utils.py * fix docstring, add type ind model rag * fix docstrings * introduce seq_length variable for cleaner code * fix black formatting * add input_ids_seq_length to modeling_rag * add input_ids_seq_length to test * retrigger checks * retrigger checks Co-authored-by: Kevin Bondzio <kev@AIM-LAP-02.local> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Kevin Bondzio <kev@AIM-LAP-02.fritz.box>
This commit is contained in:
@@ -295,6 +295,7 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
|
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
|
||||||
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
|
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
|
||||||
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
|
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
|
||||||
|
self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
|
||||||
|
|
||||||
# Fine-tuning task arguments
|
# Fine-tuning task arguments
|
||||||
self.architectures = kwargs.pop("architectures", None)
|
self.architectures = kwargs.pop("architectures", None)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
from typing import Callable, Iterable, List, Optional
|
from typing import Callable, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -647,3 +647,32 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor):
|
|||||||
scores[scores == float("inf")] = torch.finfo(scores.dtype).max
|
scores[scores == float("inf")] = torch.finfo(scores.dtype).max
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class ExponentialDecayLengthPenalty(LogitsProcessor):
|
||||||
|
r"""
|
||||||
|
[`LogitsProcessor`] that exponentially increases the score of the eos_token_id after regulation_start has been
|
||||||
|
reached.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
|
||||||
|
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
|
||||||
|
starts and `decay_factor` represents the factor of exponential decay
|
||||||
|
eos_token_id (`int`):
|
||||||
|
The id of the *end-of-sequence* token.
|
||||||
|
input_ids_seq_length (`int`):
|
||||||
|
The length of the input sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int):
|
||||||
|
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
|
||||||
|
self.regulation_factor = exponential_decay_length_penalty[1]
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
|
||||||
|
cur_len = input_ids.shape[-1]
|
||||||
|
if cur_len > self.regulation_start:
|
||||||
|
scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow(
|
||||||
|
self.regulation_factor, cur_len - self.regulation_start
|
||||||
|
)
|
||||||
|
return scores
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from .generation_beam_constraints import Constraint, DisjunctiveConstraint, Phra
|
|||||||
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||||
from .generation_logits_process import (
|
from .generation_logits_process import (
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
|
ExponentialDecayLengthPenalty,
|
||||||
ForcedBOSTokenLogitsProcessor,
|
ForcedBOSTokenLogitsProcessor,
|
||||||
ForcedEOSTokenLogitsProcessor,
|
ForcedEOSTokenLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
@@ -667,6 +668,7 @@ class GenerationMixin:
|
|||||||
repetition_penalty: float,
|
repetition_penalty: float,
|
||||||
no_repeat_ngram_size: int,
|
no_repeat_ngram_size: int,
|
||||||
encoder_no_repeat_ngram_size: int,
|
encoder_no_repeat_ngram_size: int,
|
||||||
|
input_ids_seq_length: int,
|
||||||
encoder_input_ids: torch.LongTensor,
|
encoder_input_ids: torch.LongTensor,
|
||||||
bad_words_ids: List[List[int]],
|
bad_words_ids: List[List[int]],
|
||||||
min_length: int,
|
min_length: int,
|
||||||
@@ -679,6 +681,7 @@ class GenerationMixin:
|
|||||||
num_beam_groups: int,
|
num_beam_groups: int,
|
||||||
diversity_penalty: float,
|
diversity_penalty: float,
|
||||||
remove_invalid_values: bool,
|
remove_invalid_values: bool,
|
||||||
|
exponential_decay_length_penalty: Tuple,
|
||||||
logits_processor: Optional[LogitsProcessorList],
|
logits_processor: Optional[LogitsProcessorList],
|
||||||
) -> LogitsProcessorList:
|
) -> LogitsProcessorList:
|
||||||
"""
|
"""
|
||||||
@@ -710,6 +713,11 @@ class GenerationMixin:
|
|||||||
remove_invalid_values = (
|
remove_invalid_values = (
|
||||||
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
|
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
|
||||||
)
|
)
|
||||||
|
exponential_decay_length_penalty = (
|
||||||
|
exponential_decay_length_penalty
|
||||||
|
if exponential_decay_length_penalty is not None
|
||||||
|
else self.config.exponential_decay_length_penalty
|
||||||
|
)
|
||||||
# instantiate processors list
|
# instantiate processors list
|
||||||
|
|
||||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||||
@@ -743,6 +751,10 @@ class GenerationMixin:
|
|||||||
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
|
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
|
||||||
if remove_invalid_values is True:
|
if remove_invalid_values is True:
|
||||||
processors.append(InfNanRemoveLogitsProcessor())
|
processors.append(InfNanRemoveLogitsProcessor())
|
||||||
|
if exponential_decay_length_penalty is not None:
|
||||||
|
processors.append(
|
||||||
|
ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
|
||||||
|
)
|
||||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||||
return processors
|
return processors
|
||||||
|
|
||||||
@@ -858,6 +870,7 @@ class GenerationMixin:
|
|||||||
forced_eos_token_id: Optional[int] = None,
|
forced_eos_token_id: Optional[int] = None,
|
||||||
remove_invalid_values: Optional[bool] = None,
|
remove_invalid_values: Optional[bool] = None,
|
||||||
synced_gpus: Optional[bool] = False,
|
synced_gpus: Optional[bool] = False,
|
||||||
|
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
|
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1003,6 +1016,11 @@ class GenerationMixin:
|
|||||||
crash. Note that using `remove_invalid_values` can slow down generation.
|
crash. Note that using `remove_invalid_values` can slow down generation.
|
||||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||||
|
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
|
||||||
|
This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been
|
||||||
|
generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates
|
||||||
|
where penalty starts and `decay_factor` represents the factor of exponential decay
|
||||||
|
|
||||||
model_kwargs:
|
model_kwargs:
|
||||||
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
|
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
|
||||||
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
|
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
|
||||||
@@ -1152,10 +1170,12 @@ class GenerationMixin:
|
|||||||
# if decoder-only then inputs_tensor has to be `input_ids`
|
# if decoder-only then inputs_tensor has to be `input_ids`
|
||||||
input_ids = inputs_tensor
|
input_ids = inputs_tensor
|
||||||
|
|
||||||
|
input_ids_seq_length = input_ids.shape[-1]
|
||||||
|
|
||||||
# 5. Prepare `max_length` depending on other stopping criteria
|
# 5. Prepare `max_length` depending on other stopping criteria
|
||||||
# if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens`
|
# if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens`
|
||||||
if max_length is None and max_new_tokens is not None:
|
if max_length is None and max_new_tokens is not None:
|
||||||
max_length = max_new_tokens + input_ids.shape[-1]
|
max_length = max_new_tokens + input_ids_seq_length
|
||||||
elif max_length is not None and max_new_tokens is not None:
|
elif max_length is not None and max_new_tokens is not None:
|
||||||
# Both are set, this is odd, raise a warning
|
# Both are set, this is odd, raise a warning
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
@@ -1167,10 +1187,10 @@ class GenerationMixin:
|
|||||||
# default to config if still None
|
# default to config if still None
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
|
|
||||||
if input_ids.shape[-1] >= max_length:
|
if input_ids_seq_length >= max_length:
|
||||||
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Input length of {input_ids_string} is {input_ids.shape[-1]}, but ``max_length`` is set to {max_length}. "
|
f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to {max_length}. "
|
||||||
"This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
|
"This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1202,6 +1222,7 @@ class GenerationMixin:
|
|||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
||||||
|
input_ids_seq_length=input_ids_seq_length,
|
||||||
encoder_input_ids=inputs_tensor,
|
encoder_input_ids=inputs_tensor,
|
||||||
bad_words_ids=bad_words_ids,
|
bad_words_ids=bad_words_ids,
|
||||||
min_length=min_length,
|
min_length=min_length,
|
||||||
@@ -1214,6 +1235,7 @@ class GenerationMixin:
|
|||||||
num_beam_groups=num_beam_groups,
|
num_beam_groups=num_beam_groups,
|
||||||
diversity_penalty=diversity_penalty,
|
diversity_penalty=diversity_penalty,
|
||||||
remove_invalid_values=remove_invalid_values,
|
remove_invalid_values=remove_invalid_values,
|
||||||
|
exponential_decay_length_penalty=exponential_decay_length_penalty,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
"""RAG model implementation."""
|
"""RAG model implementation."""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -1405,6 +1405,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
forced_bos_token_id: Optional[int] = None,
|
forced_bos_token_id: Optional[int] = None,
|
||||||
forced_eos_token_id: Optional[int] = None,
|
forced_eos_token_id: Optional[int] = None,
|
||||||
remove_invalid_values: Optional[bool] = None,
|
remove_invalid_values: Optional[bool] = None,
|
||||||
|
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
|
||||||
**model_kwargs
|
**model_kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1534,6 +1535,11 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
remove_invalid_values = (
|
remove_invalid_values = (
|
||||||
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
|
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
|
||||||
)
|
)
|
||||||
|
exponential_decay_length_penalty = (
|
||||||
|
exponential_decay_length_penalty
|
||||||
|
if exponential_decay_length_penalty is not None
|
||||||
|
else self.config.exponential_decay_length_penalty
|
||||||
|
)
|
||||||
|
|
||||||
# retrieve docs
|
# retrieve docs
|
||||||
if self.retriever is not None and context_input_ids is None:
|
if self.retriever is not None and context_input_ids is None:
|
||||||
@@ -1577,6 +1583,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
)
|
)
|
||||||
|
input_ids_seq_length = input_ids.shape[-1]
|
||||||
last_hidden_state = encoder_outputs["last_hidden_state"]
|
last_hidden_state = encoder_outputs["last_hidden_state"]
|
||||||
|
|
||||||
def extend_enc_output(tensor, num_beams=None):
|
def extend_enc_output(tensor, num_beams=None):
|
||||||
@@ -1603,6 +1610,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
|
||||||
|
input_ids_seq_length=input_ids_seq_length,
|
||||||
encoder_input_ids=context_input_ids,
|
encoder_input_ids=context_input_ids,
|
||||||
bad_words_ids=bad_words_ids,
|
bad_words_ids=bad_words_ids,
|
||||||
min_length=min_length,
|
min_length=min_length,
|
||||||
@@ -1615,6 +1623,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
num_beam_groups=num_beam_groups,
|
num_beam_groups=num_beam_groups,
|
||||||
diversity_penalty=diversity_penalty,
|
diversity_penalty=diversity_penalty,
|
||||||
remove_invalid_values=remove_invalid_values,
|
remove_invalid_values=remove_invalid_values,
|
||||||
|
exponential_decay_length_penalty=exponential_decay_length_penalty,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers.generation_logits_process import (
|
from transformers.generation_logits_process import (
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
|
ExponentialDecayLengthPenalty,
|
||||||
ForcedBOSTokenLogitsProcessor,
|
ForcedBOSTokenLogitsProcessor,
|
||||||
ForcedEOSTokenLogitsProcessor,
|
ForcedEOSTokenLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
@@ -504,3 +505,35 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
atol=1e-6,
|
atol=1e-6,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_exponential_decay_length_penalty(self):
|
||||||
|
vocab_size = 20
|
||||||
|
batch_size = 4
|
||||||
|
eos_token_id = 0
|
||||||
|
|
||||||
|
penalty_start = 5
|
||||||
|
penalty_factor = 1.1
|
||||||
|
|
||||||
|
input_ids = ids_tensor((batch_size, 2), vocab_size=vocab_size)
|
||||||
|
input_ids_seq_length = input_ids.shape[-1]
|
||||||
|
|
||||||
|
length_decay_processor = ExponentialDecayLengthPenalty(
|
||||||
|
exponential_decay_length_penalty=(penalty_start, penalty_factor),
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
input_ids_seq_length=input_ids_seq_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
# check that penalty is not applied before start
|
||||||
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
|
scores_before_start = length_decay_processor(input_ids, scores)
|
||||||
|
self.assertListEqual(scores_before_start[:, eos_token_id].tolist(), scores[:, eos_token_id].tolist())
|
||||||
|
|
||||||
|
# check that penalty is applied after start
|
||||||
|
input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size)
|
||||||
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
|
scores_after_start = length_decay_processor(input_ids, scores)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.gt(
|
||||||
|
scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
|
||||||
|
).all()
|
||||||
|
)
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ config_common_kwargs = {
|
|||||||
"eos_token_id": 8,
|
"eos_token_id": 8,
|
||||||
"sep_token_id": 9,
|
"sep_token_id": 9,
|
||||||
"decoder_start_token_id": 10,
|
"decoder_start_token_id": 10,
|
||||||
|
"exponential_decay_length_penalty": (5, 1.01),
|
||||||
"task_specific_params": {"translation": "some_params"},
|
"task_specific_params": {"translation": "some_params"},
|
||||||
"problem_type": "regression",
|
"problem_type": "regression",
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user