Add hallucination filter (#18675)
* Add hallucination penalty * Make quality changes * Inverse penalty * Fix imports & quality * Fix name spelling issue * set encoder_repetition_penalty and fix quality * Fix failing test * Add to config_common_kwargs * Fix modelling_rag error * Update src/transformers/generation_logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Remove breakpoint * Make style fixes * Update encoder_repetition_penalty default value * Merge latest main changes * Make fixup changes * Add EncoderRepetitionPenaltyLogitsProcessor to generation/__init__.py * Fix repo-inconsistency * Remove venv * Remove tensorflow-macos & add tests * Add documentation * Fix quality issues * move encoder_repetition_penalty to config * Update src/transformers/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Remove encoder_repetition_penalty from tests * Fix type error * Fix format error Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -28,6 +28,7 @@ if is_torch_available():
|
||||
|
||||
from transformers.generation import (
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
EncoderRepetitionPenaltyLogitsProcessor,
|
||||
EpsilonLogitsWarper,
|
||||
EtaLogitsWarper,
|
||||
ExponentialDecayLengthPenalty,
|
||||
@@ -175,6 +176,31 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2)
|
||||
self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2)
|
||||
|
||||
def test_encoder_repetition_penalty_dist_process(self):
|
||||
input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
|
||||
vocab_size = 10
|
||||
|
||||
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
|
||||
|
||||
# give values special values
|
||||
scores[0, 0] = -(1 / vocab_size)
|
||||
scores[1, 5] = 4 / vocab_size
|
||||
|
||||
rep_penalty_proc = EncoderRepetitionPenaltyLogitsProcessor(penalty=2.0, encoder_input_ids=input_ids)
|
||||
|
||||
scores = rep_penalty_proc(input_ids, scores.clone())
|
||||
|
||||
# check that values were correctly changed
|
||||
self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) / 2)
|
||||
self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) * 2)
|
||||
|
||||
self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) * 2)
|
||||
self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) * 2)
|
||||
|
||||
# check that values not in the encoder ids were NOT changed
|
||||
self.assertAlmostEqual(scores[0, 2].item(), (1 / vocab_size))
|
||||
self.assertAlmostEqual(scores[1, 2].item(), (1 / vocab_size))
|
||||
|
||||
def test_top_k_dist_warper(self):
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
|
||||
Reference in New Issue
Block a user