Add hallucination filter (#18675)

* Add hallucination penalty

* Make quality changes

* Inverse penalty

* Fix imports & quality

* Fix name spelling issue

* set encoder_repetition_penalty and fix quality

* Fix failing test

* Add to config_common_kwargs

* Fix modelling_rag error

* Update src/transformers/generation_logits_process.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Remove breakpoint

* Make style fixes

* Update encoder_repetition_penalty default value

* Merge latest main changes

* Make fixup changes

* Add EncoderRepetitionPenaltyLogitsProcessor to generation/__init__.py

* Fix repo-inconsistency

* Remove venv

* Remove tensorflow-macos & add tests

* Add documentation

* Fix quality issues

* move encoder_repetition_penalty to config

* Update src/transformers/configuration_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Remove encoder_repetition_penalty from tests

* Fix type error

* Fix format error

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Karim Foda
2023-01-19 18:20:25 +02:00
committed by GitHub
parent e9b4800dda
commit b9403e9516
5 changed files with 70 additions and 0 deletions

View File

@@ -58,6 +58,7 @@ else:
"NoRepeatNGramLogitsProcessor", "NoRepeatNGramLogitsProcessor",
"PrefixConstrainedLogitsProcessor", "PrefixConstrainedLogitsProcessor",
"RepetitionPenaltyLogitsProcessor", "RepetitionPenaltyLogitsProcessor",
"EncoderRepetitionPenaltyLogitsProcessor",
"TemperatureLogitsWarper", "TemperatureLogitsWarper",
"TopKLogitsWarper", "TopKLogitsWarper",
"TopPLogitsWarper", "TopPLogitsWarper",
@@ -164,6 +165,7 @@ if TYPE_CHECKING:
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .logits_process import ( from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor, EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper, EpsilonLogitsWarper,
EtaLogitsWarper, EtaLogitsWarper,
ExponentialDecayLengthPenalty, ExponentialDecayLengthPenalty,

View File

@@ -127,6 +127,9 @@ class GenerationConfig(PushToHubMixin):
repetition_penalty (`float`, *optional*, defaults to 1.0): repetition_penalty (`float`, *optional*, defaults to 1.0):
The parameter for repetition penalty. 1.0 means no penalty. See [this The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
encoder_repetition_penalty (`float`, *optional*, defaults to 1.0):
The paramater for encoder_repetition_penalty. An exponential penalty on sequences that are not in the
original input. 1.0 means no penalty.
length_penalty (`float`, *optional*, defaults to 1.0): length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
@@ -239,6 +242,7 @@ class GenerationConfig(PushToHubMixin):
self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0) self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0) self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
self.encoder_repetition_penalty = kwargs.pop("encoder_repetition_penalty", 1.0)
self.length_penalty = kwargs.pop("length_penalty", 1.0) self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.bad_words_ids = kwargs.pop("bad_words_ids", None) self.bad_words_ids = kwargs.pop("bad_words_ids", None)

View File

@@ -204,6 +204,34 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
return scores return scores
class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input.
Args:
hallucination_penalty (`float`):
The parameter for hallucination penalty. 1.0 means no penalty.
encoder_input_ids (`torch.LongTensor`):
The encoder_input_ids that should not be repeated within the decoder ids.
"""
def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
self.penalty = 1 / penalty
self.encoder_input_ids = encoder_input_ids
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
score = torch.gather(scores, 1, self.encoder_input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores.scatter_(1, self.encoder_input_ids, score)
return scores
class TopPLogitsWarper(LogitsWarper): class TopPLogitsWarper(LogitsWarper):
""" """
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.

View File

@@ -39,6 +39,7 @@ from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScor
from .configuration_utils import GenerationConfig from .configuration_utils import GenerationConfig
from .logits_process import ( from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor, EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper, EpsilonLogitsWarper,
EtaLogitsWarper, EtaLogitsWarper,
ExponentialDecayLengthPenalty, ExponentialDecayLengthPenalty,
@@ -799,6 +800,15 @@ class GenerationMixin:
num_beam_groups=generation_config.num_beam_groups, num_beam_groups=generation_config.num_beam_groups,
) )
) )
if (
generation_config.encoder_repetition_penalty is not None
and generation_config.encoder_repetition_penalty != 1.0
):
processors.append(
EncoderRepetitionPenaltyLogitsProcessor(
penalty=generation_config.encoder_repetition_penalty, encoder_input_ids=encoder_input_ids
)
)
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:

View File

@@ -28,6 +28,7 @@ if is_torch_available():
from transformers.generation import ( from transformers.generation import (
EncoderNoRepeatNGramLogitsProcessor, EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper, EpsilonLogitsWarper,
EtaLogitsWarper, EtaLogitsWarper,
ExponentialDecayLengthPenalty, ExponentialDecayLengthPenalty,
@@ -175,6 +176,31 @@ class LogitsProcessorTest(unittest.TestCase):
self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2) self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2) self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2)
def test_encoder_repetition_penalty_dist_process(self):
input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
vocab_size = 10
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
# give values special values
scores[0, 0] = -(1 / vocab_size)
scores[1, 5] = 4 / vocab_size
rep_penalty_proc = EncoderRepetitionPenaltyLogitsProcessor(penalty=2.0, encoder_input_ids=input_ids)
scores = rep_penalty_proc(input_ids, scores.clone())
# check that values were correctly changed
self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) / 2)
self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) * 2)
self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) * 2)
self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) * 2)
# check that values not in the encoder ids were NOT changed
self.assertAlmostEqual(scores[0, 2].item(), (1 / vocab_size))
self.assertAlmostEqual(scores[1, 2].item(), (1 / vocab_size))
def test_top_k_dist_warper(self): def test_top_k_dist_warper(self):
input_ids = None input_ids = None
vocab_size = 10 vocab_size = 10