Fix decoding score comparison when using logits processors or warpers (#10638)
* Normalize using a logits warper * Add a flag in `generate` to support the logit renormalization * Add in RAG
This commit is contained in:
@@ -679,3 +679,16 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
|
|||||||
self.regulation_factor, cur_len - self.regulation_start
|
self.regulation_factor, cur_len - self.regulation_start
|
||||||
)
|
)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class LogitNormalization(LogitsProcessor, LogitsWarper):
|
||||||
|
r"""
|
||||||
|
[`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
|
||||||
|
the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in
|
||||||
|
this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that
|
||||||
|
the scores are normalized when comparing the hypotheses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||||
|
scores = scores.log_softmax(dim=-1)
|
||||||
|
return scores
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from .generation_logits_process import (
|
|||||||
ForcedEOSTokenLogitsProcessor,
|
ForcedEOSTokenLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
InfNanRemoveLogitsProcessor,
|
InfNanRemoveLogitsProcessor,
|
||||||
|
LogitNormalization,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
NoBadWordsLogitsProcessor,
|
NoBadWordsLogitsProcessor,
|
||||||
@@ -636,6 +637,7 @@ class GenerationMixin:
|
|||||||
typical_p: Optional[float] = None,
|
typical_p: Optional[float] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
num_beams: Optional[int] = None,
|
num_beams: Optional[int] = None,
|
||||||
|
renormalize_logits: Optional[bool] = None,
|
||||||
) -> LogitsProcessorList:
|
) -> LogitsProcessorList:
|
||||||
"""
|
"""
|
||||||
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
|
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
|
||||||
@@ -660,6 +662,9 @@ class GenerationMixin:
|
|||||||
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
|
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
|
||||||
if typical_p is not None and typical_p < 1.0:
|
if typical_p is not None and typical_p < 1.0:
|
||||||
warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
|
warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
|
||||||
|
# `LogitNormalization` should always be the last logit processor, when present
|
||||||
|
if renormalize_logits is True:
|
||||||
|
warpers.append(LogitNormalization())
|
||||||
return warpers
|
return warpers
|
||||||
|
|
||||||
def _get_logits_processor(
|
def _get_logits_processor(
|
||||||
@@ -682,6 +687,7 @@ class GenerationMixin:
|
|||||||
remove_invalid_values: bool,
|
remove_invalid_values: bool,
|
||||||
exponential_decay_length_penalty: Tuple,
|
exponential_decay_length_penalty: Tuple,
|
||||||
logits_processor: Optional[LogitsProcessorList],
|
logits_processor: Optional[LogitsProcessorList],
|
||||||
|
renormalize_logits: Optional[bool],
|
||||||
) -> LogitsProcessorList:
|
) -> LogitsProcessorList:
|
||||||
"""
|
"""
|
||||||
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
|
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
|
||||||
@@ -754,6 +760,9 @@ class GenerationMixin:
|
|||||||
ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
|
ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
|
||||||
)
|
)
|
||||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||||
|
# `LogitNormalization` should always be the last logit processor, when present
|
||||||
|
if renormalize_logits is True:
|
||||||
|
processors.append(LogitNormalization())
|
||||||
return processors
|
return processors
|
||||||
|
|
||||||
def _get_stopping_criteria(
|
def _get_stopping_criteria(
|
||||||
@@ -858,6 +867,7 @@ class GenerationMixin:
|
|||||||
diversity_penalty: Optional[float] = None,
|
diversity_penalty: Optional[float] = None,
|
||||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||||
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
|
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
|
||||||
|
renormalize_logits: Optional[bool] = None,
|
||||||
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
|
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
|
||||||
constraints: Optional[List[Constraint]] = None,
|
constraints: Optional[List[Constraint]] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
@@ -986,6 +996,10 @@ class GenerationMixin:
|
|||||||
Custom logits processors that complement the default logits processors built from arguments and a
|
Custom logits processors that complement the default logits processors built from arguments and a
|
||||||
model's config. If a logit processor is passed that is already created with the arguments or a model's
|
model's config. If a logit processor is passed that is already created with the arguments or a model's
|
||||||
config an error is thrown. This feature is intended for advanced users.
|
config an error is thrown. This feature is intended for advanced users.
|
||||||
|
renormalize_logits: (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to renormalize the logits after applying all the logits processors or warpers (including the
|
||||||
|
custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
|
||||||
|
score logits are normalized but some logit processors or warpers break the normalization.
|
||||||
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
||||||
Custom stopping criteria that complement the default stopping criteria built from arguments and a
|
Custom stopping criteria that complement the default stopping criteria built from arguments and a
|
||||||
model's config. If a stopping criteria is passed that is already created with the arguments or a
|
model's config. If a stopping criteria is passed that is already created with the arguments or a
|
||||||
@@ -1241,6 +1255,7 @@ class GenerationMixin:
|
|||||||
remove_invalid_values=remove_invalid_values,
|
remove_invalid_values=remove_invalid_values,
|
||||||
exponential_decay_length_penalty=exponential_decay_length_penalty,
|
exponential_decay_length_penalty=exponential_decay_length_penalty,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
|
renormalize_logits=renormalize_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 8. prepare stopping criteria
|
# 8. prepare stopping criteria
|
||||||
@@ -1271,7 +1286,12 @@ class GenerationMixin:
|
|||||||
elif is_sample_gen_mode:
|
elif is_sample_gen_mode:
|
||||||
# 10. prepare logits warper
|
# 10. prepare logits warper
|
||||||
logits_warper = self._get_logits_warper(
|
logits_warper = self._get_logits_warper(
|
||||||
top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
typical_p=typical_p,
|
||||||
|
temperature=temperature,
|
||||||
|
num_beams=num_beams,
|
||||||
|
renormalize_logits=renormalize_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
|
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
|
||||||
@@ -1333,7 +1353,12 @@ class GenerationMixin:
|
|||||||
elif is_beam_sample_gen_mode:
|
elif is_beam_sample_gen_mode:
|
||||||
# 10. prepare logits warper
|
# 10. prepare logits warper
|
||||||
logits_warper = self._get_logits_warper(
|
logits_warper = self._get_logits_warper(
|
||||||
top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
typical_p=typical_p,
|
||||||
|
temperature=temperature,
|
||||||
|
num_beams=num_beams,
|
||||||
|
renormalize_logits=renormalize_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
if stopping_criteria.max_length is None:
|
if stopping_criteria.max_length is None:
|
||||||
|
|||||||
@@ -1400,6 +1400,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
n_docs: Optional[int] = None,
|
n_docs: Optional[int] = None,
|
||||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
|
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
|
||||||
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
|
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
|
||||||
|
renormalize_logits: Optional[bool] = None,
|
||||||
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
|
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
|
||||||
forced_bos_token_id: Optional[int] = None,
|
forced_bos_token_id: Optional[int] = None,
|
||||||
forced_eos_token_id: Optional[int] = None,
|
forced_eos_token_id: Optional[int] = None,
|
||||||
@@ -1624,6 +1625,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
remove_invalid_values=remove_invalid_values,
|
remove_invalid_values=remove_invalid_values,
|
||||||
exponential_decay_length_penalty=exponential_decay_length_penalty,
|
exponential_decay_length_penalty=exponential_decay_length_penalty,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
|
renormalize_logits=renormalize_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
if num_beams == 1:
|
if num_beams == 1:
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ if is_torch_available():
|
|||||||
ForcedEOSTokenLogitsProcessor,
|
ForcedEOSTokenLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
InfNanRemoveLogitsProcessor,
|
InfNanRemoveLogitsProcessor,
|
||||||
|
LogitNormalization,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
NoBadWordsLogitsProcessor,
|
NoBadWordsLogitsProcessor,
|
||||||
@@ -537,3 +538,18 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
|
scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
|
||||||
).all()
|
).all()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_normalization(self):
|
||||||
|
input_ids = None
|
||||||
|
|
||||||
|
scores = torch.tensor(
|
||||||
|
[[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]], device=torch_device, dtype=torch.float
|
||||||
|
)
|
||||||
|
|
||||||
|
logit_normalization = LogitNormalization()
|
||||||
|
normalized_scores = logit_normalization(input_ids, scores).exp()
|
||||||
|
|
||||||
|
ones = torch.ones(scores.shape[0], device=torch_device, dtype=torch.float)
|
||||||
|
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
|
||||||
|
|
||||||
|
self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))
|
||||||
|
|||||||
Reference in New Issue
Block a user