Add hallucination filter (#18675)
* Add hallucination penalty * Make quality changes * Inverse penalty * Fix imports & quality * Fix name spelling issue * set encoder_repetition_penalty and fix quality * Fix failing test * Add to config_common_kwargs * Fix modelling_rag error * Update src/transformers/generation_logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Remove breakpoint * Make style fixes * Update encoder_repetition_penalty default value * Merge latest main changes * Make fixup changes * Add EncoderRepetitionPenaltyLogitsProcessor to generation/__init__.py * Fix repo-inconsistency * Remove venv * Remove tensorflow-macos & add tests * Add documentation * Fix quality issues * move encoder_repetition_penalty to config * Update src/transformers/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Remove encoder_repetition_penalty from tests * Fix type error * Fix format error Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -58,6 +58,7 @@ else:
|
|||||||
"NoRepeatNGramLogitsProcessor",
|
"NoRepeatNGramLogitsProcessor",
|
||||||
"PrefixConstrainedLogitsProcessor",
|
"PrefixConstrainedLogitsProcessor",
|
||||||
"RepetitionPenaltyLogitsProcessor",
|
"RepetitionPenaltyLogitsProcessor",
|
||||||
|
"EncoderRepetitionPenaltyLogitsProcessor",
|
||||||
"TemperatureLogitsWarper",
|
"TemperatureLogitsWarper",
|
||||||
"TopKLogitsWarper",
|
"TopKLogitsWarper",
|
||||||
"TopPLogitsWarper",
|
"TopPLogitsWarper",
|
||||||
@@ -164,6 +165,7 @@ if TYPE_CHECKING:
|
|||||||
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||||
from .logits_process import (
|
from .logits_process import (
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
|
EncoderRepetitionPenaltyLogitsProcessor,
|
||||||
EpsilonLogitsWarper,
|
EpsilonLogitsWarper,
|
||||||
EtaLogitsWarper,
|
EtaLogitsWarper,
|
||||||
ExponentialDecayLengthPenalty,
|
ExponentialDecayLengthPenalty,
|
||||||
|
|||||||
@@ -127,6 +127,9 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
repetition_penalty (`float`, *optional*, defaults to 1.0):
|
repetition_penalty (`float`, *optional*, defaults to 1.0):
|
||||||
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
encoder_repetition_penalty (`float`, *optional*, defaults to 1.0):
|
||||||
|
The parameter for encoder_repetition_penalty. An exponential penalty on sequences that are not in the
|
||||||
|
original input. 1.0 means no penalty.
|
||||||
length_penalty (`float`, *optional*, defaults to 1.0):
|
length_penalty (`float`, *optional*, defaults to 1.0):
|
||||||
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
|
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
|
||||||
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
|
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
|
||||||
@@ -239,6 +242,7 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
|
self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
|
||||||
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
|
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
|
||||||
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
||||||
|
self.encoder_repetition_penalty = kwargs.pop("encoder_repetition_penalty", 1.0)
|
||||||
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
||||||
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
||||||
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
||||||
|
|||||||
@@ -204,6 +204,34 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
|||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
||||||
|
r"""
|
||||||
|
[`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
penalty (`float`):
|
||||||
|
The parameter for hallucination penalty. 1.0 means no penalty.
|
||||||
|
encoder_input_ids (`torch.LongTensor`):
|
||||||
|
The encoder_input_ids whose tokens have their scores adjusted; with a penalty above 1.0 these tokens become more likely in the decoder ids.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
|
||||||
|
if not isinstance(penalty, float) or not (penalty > 0):
|
||||||
|
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
||||||
|
|
||||||
|
self.penalty = 1 / penalty
|
||||||
|
self.encoder_input_ids = encoder_input_ids
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
score = torch.gather(scores, 1, self.encoder_input_ids)
|
||||||
|
|
||||||
|
# self.penalty is the inverse of the configured penalty: multiplying negative scores and dividing positive ones by it raises the scores of encoder tokens when penalty > 1
|
||||||
|
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||||
|
|
||||||
|
scores.scatter_(1, self.encoder_input_ids, score)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class TopPLogitsWarper(LogitsWarper):
|
class TopPLogitsWarper(LogitsWarper):
|
||||||
"""
|
"""
|
||||||
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScor
|
|||||||
from .configuration_utils import GenerationConfig
|
from .configuration_utils import GenerationConfig
|
||||||
from .logits_process import (
|
from .logits_process import (
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
|
EncoderRepetitionPenaltyLogitsProcessor,
|
||||||
EpsilonLogitsWarper,
|
EpsilonLogitsWarper,
|
||||||
EtaLogitsWarper,
|
EtaLogitsWarper,
|
||||||
ExponentialDecayLengthPenalty,
|
ExponentialDecayLengthPenalty,
|
||||||
@@ -799,6 +800,15 @@ class GenerationMixin:
|
|||||||
num_beam_groups=generation_config.num_beam_groups,
|
num_beam_groups=generation_config.num_beam_groups,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
|
generation_config.encoder_repetition_penalty is not None
|
||||||
|
and generation_config.encoder_repetition_penalty != 1.0
|
||||||
|
):
|
||||||
|
processors.append(
|
||||||
|
EncoderRepetitionPenaltyLogitsProcessor(
|
||||||
|
penalty=generation_config.encoder_repetition_penalty, encoder_input_ids=encoder_input_ids
|
||||||
|
)
|
||||||
|
)
|
||||||
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
|
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
|
||||||
processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
|
processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
|
||||||
if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
|
if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers.generation import (
|
from transformers.generation import (
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
|
EncoderRepetitionPenaltyLogitsProcessor,
|
||||||
EpsilonLogitsWarper,
|
EpsilonLogitsWarper,
|
||||||
EtaLogitsWarper,
|
EtaLogitsWarper,
|
||||||
ExponentialDecayLengthPenalty,
|
ExponentialDecayLengthPenalty,
|
||||||
@@ -175,6 +176,31 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2)
|
self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2)
|
||||||
self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2)
|
self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2)
|
||||||
|
|
||||||
|
def test_encoder_repetition_penalty_dist_process(self):
|
||||||
|
input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
|
||||||
|
vocab_size = 10
|
||||||
|
|
||||||
|
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
|
||||||
|
|
||||||
|
# give values special values
|
||||||
|
scores[0, 0] = -(1 / vocab_size)
|
||||||
|
scores[1, 5] = 4 / vocab_size
|
||||||
|
|
||||||
|
rep_penalty_proc = EncoderRepetitionPenaltyLogitsProcessor(penalty=2.0, encoder_input_ids=input_ids)
|
||||||
|
|
||||||
|
scores = rep_penalty_proc(input_ids, scores.clone())
|
||||||
|
|
||||||
|
# check that values were correctly changed
|
||||||
|
self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) / 2)
|
||||||
|
self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) * 2)
|
||||||
|
|
||||||
|
self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) * 2)
|
||||||
|
self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) * 2)
|
||||||
|
|
||||||
|
# check that values not in the encoder ids were NOT changed
|
||||||
|
self.assertAlmostEqual(scores[0, 2].item(), (1 / vocab_size))
|
||||||
|
self.assertAlmostEqual(scores[1, 2].item(), (1 / vocab_size))
|
||||||
|
|
||||||
def test_top_k_dist_warper(self):
|
def test_top_k_dist_warper(self):
|
||||||
input_ids = None
|
input_ids = None
|
||||||
vocab_size = 10
|
vocab_size = 10
|
||||||
|
|||||||
Reference in New Issue
Block a user