From b9403e951661b53630afd95166874f75ede885c4 Mon Sep 17 00:00:00 2001 From: Karim Foda <35491698+KMFODA@users.noreply.github.com> Date: Thu, 19 Jan 2023 18:20:25 +0200 Subject: [PATCH] Add hallucination filter (#18675) * Add hallucination penalty * Make quality changes * Inverse penalty * Fix imports & quality * Fix name spelling issue * set encoder_repetition_penalty and fix quality * Fix failing test * Add to config_common_kwargs * Fix modelling_rag error * Update src/transformers/generation_logits_process.py Co-authored-by: Joao Gante * Remove breakpoint * Make style fixes * Update encoder_repetition_penalty default value * Merge latest main changes * Make fixup changes * Add EncoderRepetitionPenaltyLogitsProcessor to generation/__init__.py * Fix repo-inconsistency * Remove venv * Remove tensorflow-macos & add tests * Add documentation * Fix quality issues * move encoder_repetition_penalty to config * Update src/transformers/configuration_utils.py Co-authored-by: Joao Gante * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante * Remove encoder_repetition_penalty from tests * Fix type error * Fix format error Co-authored-by: Joao Gante --- src/transformers/generation/__init__.py | 2 ++ .../generation/configuration_utils.py | 4 +++ src/transformers/generation/logits_process.py | 28 +++++++++++++++++++ src/transformers/generation/utils.py | 10 +++++++ tests/generation/test_logits_process.py | 26 +++++++++++++++++ 5 files changed, 70 insertions(+) diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 95d4c9fff7..f820c50133 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -58,6 +58,7 @@ else: "NoRepeatNGramLogitsProcessor", "PrefixConstrainedLogitsProcessor", "RepetitionPenaltyLogitsProcessor", + "EncoderRepetitionPenaltyLogitsProcessor", "TemperatureLogitsWarper", "TopKLogitsWarper", "TopPLogitsWarper", @@ -164,6 +165,7 @@ if 
TYPE_CHECKING: from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .logits_process import ( EncoderNoRepeatNGramLogitsProcessor, + EncoderRepetitionPenaltyLogitsProcessor, EpsilonLogitsWarper, EtaLogitsWarper, ExponentialDecayLengthPenalty, diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index aacb1d29f4..7d6118892c 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -127,6 +127,9 @@ class GenerationConfig(PushToHubMixin): repetition_penalty (`float`, *optional*, defaults to 1.0): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + encoder_repetition_penalty (`float`, *optional*, defaults to 1.0): + The parameter for encoder_repetition_penalty. An exponential penalty on sequences that are not in the + original input. 1.0 means no penalty. length_penalty (`float`, *optional*, defaults to 1.0): Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence. 
Since the score is the log @@ -239,6 +242,7 @@ class GenerationConfig(PushToHubMixin): self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0) self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0) self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) + self.encoder_repetition_penalty = kwargs.pop("encoder_repetition_penalty", 1.0) self.length_penalty = kwargs.pop("length_penalty", 1.0) self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) self.bad_words_ids = kwargs.pop("bad_words_ids", None) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index e5b360a1c8..15e35bd21e 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -204,6 +204,34 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor): return scores +class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input. + + Args: + penalty (`float`): + The parameter for hallucination penalty. 1.0 means no penalty. + encoder_input_ids (`torch.LongTensor`): + The encoder_input_ids that should be repeated within the decoder ids. 
+ """ + + def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor): + if not isinstance(penalty, float) or not (penalty > 0): + raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") + + self.penalty = 1 / penalty + self.encoder_input_ids = encoder_input_ids + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + score = torch.gather(scores, 1, self.encoder_input_ids) + + # self.penalty = 1 / penalty, so this raises the scores of encoder input tokens when the user's penalty is > 1 + score = torch.where(score < 0, score * self.penalty, score / self.penalty) + + scores.scatter_(1, self.encoder_input_ids, score) + return scores + + class TopPLogitsWarper(LogitsWarper): """ [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1ce610eabd..b47c4db3e3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -39,6 +39,7 @@ from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScor from .configuration_utils import GenerationConfig from .logits_process import ( EncoderNoRepeatNGramLogitsProcessor, + EncoderRepetitionPenaltyLogitsProcessor, EpsilonLogitsWarper, EtaLogitsWarper, ExponentialDecayLengthPenalty, @@ -799,6 +800,15 @@ class GenerationMixin: num_beam_groups=generation_config.num_beam_groups, ) ) + if ( + generation_config.encoder_repetition_penalty is not None + and generation_config.encoder_repetition_penalty != 1.0 + ): + processors.append( + EncoderRepetitionPenaltyLogitsProcessor( + penalty=generation_config.encoder_repetition_penalty, encoder_input_ids=encoder_input_ids + ) + ) if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: 
processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index e81a5c865f..c377a23e7a 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -28,6 +28,7 @@ if is_torch_available(): from transformers.generation import ( EncoderNoRepeatNGramLogitsProcessor, + EncoderRepetitionPenaltyLogitsProcessor, EpsilonLogitsWarper, EtaLogitsWarper, ExponentialDecayLengthPenalty, @@ -175,6 +176,31 @@ class LogitsProcessorTest(unittest.TestCase): self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2) self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2) + def test_encoder_repetition_penalty_dist_process(self): + input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long) + vocab_size = 10 + + scores = self._get_uniform_logits(batch_size=2, length=vocab_size) + + # give values special values + scores[0, 0] = -(1 / vocab_size) + scores[1, 5] = 4 / vocab_size + + rep_penalty_proc = EncoderRepetitionPenaltyLogitsProcessor(penalty=2.0, encoder_input_ids=input_ids) + + scores = rep_penalty_proc(input_ids, scores.clone()) + + # check that values were correctly changed + self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) / 2) + self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) * 2) + + self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) * 2) + self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) * 2) + + # check that values not in the encoder ids were NOT changed + self.assertAlmostEqual(scores[0, 2].item(), (1 / vocab_size)) + self.assertAlmostEqual(scores[1, 2].item(), (1 / vocab_size)) + def test_top_k_dist_warper(self): input_ids = None vocab_size = 10