From 162056a3f433132fa6ce32c87663742803036b59 Mon Sep 17 00:00:00 2001 From: Vladislav Bronzov <58587565+VladOS95-cyber@users.noreply.github.com> Date: Thu, 19 Sep 2024 18:35:44 +0200 Subject: [PATCH] =?UTF-8?q?change=20sequence=5Fbias=20type=20of=20Sequence?= =?UTF-8?q?BiasLogitsProcessor=20to=20list,=20add=E2=80=A6=20(#33375)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * change sequence_bias type of SequenceBiasLogitsProcessor tp list, add config tests for all processors * fix format * small fix for all_token_bias_pairs_are_valid internal func * small typo fix in description * improve test impl, some SequenceBiasLogitsProcessor refactoring --- src/transformers/generation/logits_process.py | 53 ++- tests/generation/test_configuration_utils.py | 449 +++++++++++++++++- 2 files changed, 486 insertions(+), 16 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index c586a97459..d88c7a17d8 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -15,7 +15,7 @@ import inspect import math -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Callable, Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -1064,8 +1064,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): Args: - sequence_bias (`Dict[Tuple[int], float]`): - Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the + sequence_bias (`List[List[Union[List[int], float]]]`): + List of lists that maps a sequence of tokens to its bias term (e.g. `[[[10, 45], -2.0], + [[64], -7.5]]`). Positive biases increase the odds of the sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be completed (in the token selection step after this processor is applied). @@ -1087,12 +1088,12 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=True) - >>> def get_tokens_as_tuple(word): - ... return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]) + >>> def get_tokens(word): + ... return tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0] >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations - >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0} + >>> sequence_bias = [get_tokens("Trump"), -10.0] >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias) >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0]) The full name of Donald is Donald J. Donald, @@ -1102,16 +1103,17 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): The full name of Donald is Donald Rumsfeld, >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations - >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0} + >>> sequence_bias = [get_tokens("Donald Duck"), 10.0] >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias) >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0]) The full name of Donald is Donald Duck. ``` """ - def __init__(self, sequence_bias: Dict[Tuple[int], float]): + def __init__(self, sequence_bias: List[List[Union[List[int], float]]]): self.sequence_bias = sequence_bias self._validate_arguments() + self._convert_list_arguments_into_dict() # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size # is infered in the first usage, which inhibits initializing here) @@ -1178,11 +1180,15 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): def _validate_arguments(self): sequence_bias = self.sequence_bias - if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0: - raise ValueError(f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}.") - if any(not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()): + if not isinstance(sequence_bias, dict) and not isinstance(sequence_bias, list) or len(sequence_bias) == 0: + raise ValueError( + f"`sequence_bias` has to be a non-empty dictionary, or non-empty list of lists but is {sequence_bias}." + ) + if isinstance(sequence_bias, dict) and any( + not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys() + ): raise ValueError(f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}.") - if any( + if isinstance(sequence_bias, dict) and any( any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in sequence_ids) or len(sequence_ids) == 0 for sequence_ids in sequence_bias.keys() @@ -1191,9 +1197,30 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is " f"{sequence_bias}." ) - if any(not isinstance(bias, float) for bias in sequence_bias.values()): + + def all_token_bias_pairs_are_valid(sequence): + return ( + isinstance(sequence[0], list) + and all(isinstance(token_id, (int, np.integer)) and token_id > 0 for token_id in sequence[0]) + and isinstance(sequence[1], float) + ) + + if isinstance(sequence_bias, list) and any( + (not all_token_bias_pairs_are_valid(sequence)) or len(sequence) == 0 for sequence in sequence_bias + ): + raise ValueError( + f"Each element in `sequence_bias` has to be a non-empty list of lists of positive integers and float, but is " + f"{sequence_bias}." + ) + if isinstance(sequence_bias, dict) and any(not isinstance(bias, float) for bias in sequence_bias.values()): raise ValueError(f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}.") + def _convert_list_arguments_into_dict(self): + """BC: we used to accept `dict{tuple of tokens: float}` directly, now we expect a list""" + if isinstance(self.sequence_bias, list): + temp_sequence = self.sequence_bias + self.sequence_bias = {tuple(sublist[0]): sublist[1] for sublist in temp_sequence} + class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor): """ diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index cd5f3d5016..1e11a9679b 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -23,9 +23,41 @@ from pathlib import Path from huggingface_hub import HfFolder, delete_repo from parameterized import parameterized -from transformers import AutoConfig, GenerationConfig -from transformers.generation import GenerationMode -from transformers.testing_utils import TOKEN, USER, is_staging_test +from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available + + +if is_torch_available(): + import torch + +from transformers.generation import ( + ClassifierFreeGuidanceLogitsProcessor, + EncoderNoRepeatNGramLogitsProcessor, + EncoderRepetitionPenaltyLogitsProcessor, + EpsilonLogitsWarper, + EtaLogitsWarper, + ExponentialDecayLengthPenalty, + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, + GenerationMode, + HammingDiversityLogitsProcessor, + MinLengthLogitsProcessor, + MinNewTokensLengthLogitsProcessor, + MinPLogitsWarper, + NoBadWordsLogitsProcessor, + NoRepeatNGramLogitsProcessor, + PrefixConstrainedLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + SequenceBiasLogitsProcessor, + SuppressTokensAtBeginLogitsProcessor, + SuppressTokensLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TypicalLogitsWarper, + UnbatchedClassifierFreeGuidanceLogitsProcessor, + WatermarkLogitsProcessor, +) +from transformers.testing_utils import TOKEN, USER, is_staging_test, torch_device class GenerationConfigTest(unittest.TestCase): @@ -225,6 +257,417 @@ class GenerationConfigTest(unittest.TestCase): self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION) +class GenerationConfigSerializationTest(unittest.TestCase): + def test_serialize_generation_sequence_bias(self): + """Tests that GenerationConfig is serialized and SequenceBiasLogitsProcessor is initialized with sequence_bias parameter""" + generation_config = GenerationConfig() + sequence_bias = [[[45, 67], -0.6], [[89], 1.2]] + generation_config.sequence_bias = sequence_bias + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertSequenceEqual(new_config.sequence_bias, sequence_bias) + + expected_sequence_bias = {(45, 67): -0.6, (89,): 1.2} + bias_logits_processor = SequenceBiasLogitsProcessor(new_config.sequence_bias) + self.assertDictEqual(bias_logits_processor.sequence_bias, expected_sequence_bias) + + def test_serialize_generation_min_length_eos_token(self): + """Tests that GenerationConfig is serialized and MinLengthLogitsProcessor is initialized with min_length and eos_token_id""" + eos_token_id = 0 + min_length = 10 + + generation_config = GenerationConfig(min_length=min_length, eos_token_id=eos_token_id) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.min_length, min_length) + self.assertEqual(new_config.eos_token_id, eos_token_id) + + min_dist_processor = MinLengthLogitsProcessor( + min_length=new_config.min_length, eos_token_id=new_config.eos_token_id + ) + self.assertEqual(min_dist_processor.min_length, min_length) + self.assertEqual(min_dist_processor.eos_token_id, eos_token_id) + + def test_serialize_generation_min_new_tokens(self): + """Tests that GenerationConfig is serialized and MinNewTokensLengthLogitsProcessor is initialized with min_new_tokens""" + eos_token_id = 0 + min_new_tokens = 5 + prompt_length_to_skip = 2 + + generation_config = GenerationConfig(min_new_tokens=min_new_tokens) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.min_new_tokens, min_new_tokens) + + min_new_tokens_processor = MinNewTokensLengthLogitsProcessor( + prompt_length_to_skip=prompt_length_to_skip, + min_new_tokens=new_config.min_new_tokens, + eos_token_id=eos_token_id, + ) + self.assertEqual(min_new_tokens_processor.min_new_tokens, min_new_tokens) + + def test_serialize_generation_temperature(self): + """Tests that GenerationConfig is serialized and TemperatureLogitsWarper is initialized with temperature""" + temperature = 2.0 + + generation_config = GenerationConfig(temperature=temperature, do_sample=True) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.temperature, temperature) + + temperature_logits_warper = TemperatureLogitsWarper(temperature=new_config.temperature) + self.assertEqual(temperature_logits_warper.temperature, temperature) + + def test_serialize_generation_repetition_penalty(self): + """Tests that GenerationConfig is serialized and RepetitionPenaltyLogitsProcessor is initialized with repetition_penalty""" + penalty = 2.0 + + generation_config = GenerationConfig(repetition_penalty=penalty) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.repetition_penalty, penalty) + + rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=new_config.repetition_penalty) + self.assertEqual(rep_penalty_proc.penalty, penalty) + + def test_serialize_generation_encoder_repetition_penalty(self): + """Tests that GenerationConfig is serialized and EncoderRepetitionPenaltyLogitsProcessor is initialized with penalty and input_ids""" + penalty = 2.0 + input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long) + + generation_config = GenerationConfig(encoder_repetition_penalty=penalty) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.encoder_repetition_penalty, penalty) + + rep_penalty_proc = EncoderRepetitionPenaltyLogitsProcessor( + penalty=new_config.encoder_repetition_penalty, encoder_input_ids=input_ids + ) + self.assertEqual(rep_penalty_proc.penalty, 1 / penalty) + torch.testing.assert_close(rep_penalty_proc.encoder_input_ids, input_ids) + + def test_serialize_generation_top_p(self): + """Tests that GenerationConfig is serialized and TopPLogitsWarper is initialized with top_p""" + top_p = 0.8 + + generation_config = GenerationConfig(top_p=top_p, do_sample=True) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.top_p, top_p) + + rep_penalty_proc = TopPLogitsWarper(top_p=new_config.top_p) + self.assertEqual(rep_penalty_proc.top_p, top_p) + + def test_serialize_generation_top_k(self): + """Tests that GenerationConfig is serialized and TopKLogitsWarper is initialized with top_k""" + top_k = 2 + + generation_config = GenerationConfig(top_k=top_k, do_sample=True) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.top_k, top_k) + + top_k_logits_wrap = TopKLogitsWarper(top_k=new_config.top_k) + self.assertEqual(top_k_logits_wrap.top_k, top_k) + + def test_serialize_generation_min_p(self): + """Tests that GenerationConfig is serialized and MinPLogitsWarper is initialized with min_p""" + min_p = 0.8 + + generation_config = GenerationConfig(min_p=min_p, do_sample=True) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.min_p, min_p) + + min_k_logits_wrap = MinPLogitsWarper(min_p=new_config.min_p) + self.assertEqual(min_k_logits_wrap.min_p, min_p) + + def test_serialize_generation_typical_p(self): + """Tests that GenerationConfig is serialized and TypicalLogitsWarper is initialized with mass""" + mass = 0.8 + + generation_config = GenerationConfig(typical_p=mass, do_sample=True) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.typical_p, mass) + + typical_p_logits_wrap = TypicalLogitsWarper(mass=new_config.typical_p) + self.assertEqual(typical_p_logits_wrap.mass, mass) + + def test_serialize_generation_epsilon_cutoff(self): + """Tests that GenerationConfig is serialized and EpsilonLogitsWarper is initialized with epsilon""" + epsilon = 0.8 + + generation_config = GenerationConfig(epsilon_cutoff=epsilon, do_sample=True) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.epsilon_cutoff, epsilon) + + epsilon_logits_wrap = EpsilonLogitsWarper(epsilon=new_config.epsilon_cutoff) + self.assertEqual(epsilon_logits_wrap.epsilon, epsilon) + + def test_serialize_generation_eta_cutoff(self): + """Tests that GenerationConfig is serialized and EtaLogitsWarper is initialized with epsilon""" + epsilon = 0.8 + + generation_config = GenerationConfig(eta_cutoff=epsilon, do_sample=True) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.eta_cutoff, epsilon) + + eta_logits_wrap = EtaLogitsWarper(epsilon=new_config.eta_cutoff) + self.assertEqual(eta_logits_wrap.epsilon, epsilon) + + def test_serialize_generation_ngram_size(self): + """Tests that GenerationConfig is serialized and NoRepeatNGramLogitsProcessor is initialized with ngram_size""" + ngram_size = 2 + + generation_config = GenerationConfig(no_repeat_ngram_size=ngram_size, do_sample=True) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.no_repeat_ngram_size, ngram_size) + + no_repeat_ngram_proc = NoRepeatNGramLogitsProcessor(ngram_size=new_config.no_repeat_ngram_size) + self.assertEqual(no_repeat_ngram_proc.ngram_size, ngram_size) + + def test_serialize_generation_encoder_ngram_size(self): + """Tests that GenerationConfig is serialized and EncoderNoRepeatNGramLogitsProcessor is initialized with ngram_size""" + ngram_size = 2 + input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long) + + generation_config = GenerationConfig(encoder_no_repeat_ngram_size=ngram_size, do_sample=True) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.encoder_no_repeat_ngram_size, ngram_size) + + encoder_no_repeat_ngram_proc = EncoderNoRepeatNGramLogitsProcessor( + encoder_ngram_size=new_config.encoder_no_repeat_ngram_size, encoder_input_ids=input_ids + ) + self.assertEqual(encoder_no_repeat_ngram_proc.ngram_size, ngram_size) + + def test_serialize_generation_bad_words_ids(self): + """Tests that GenerationConfig is serialized and NoBadWordsLogitsProcessor is initialized with bad_words_ids""" + bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]] + + generation_config = GenerationConfig(bad_words_ids=bad_word_tokens) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertSequenceEqual(new_config.bad_words_ids, bad_word_tokens) + + no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=new_config.bad_words_ids) + self.assertSequenceEqual(no_bad_words_dist_proc.bad_word_ids, bad_word_tokens) + + def test_serialize_generation_num_beams(self): + """Tests that GenerationConfig is serialized and PrefixConstrainedLogitsProcessor is initialized with num_beams""" + num_beams = 1 + + def prefix_allowed_tokens_fn(batch_id, inputs_ids): + return [[0, 1], [2, 3]][batch_id] + + generation_config = GenerationConfig(num_beams=num_beams) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.num_beams, num_beams) + + prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor( + prefix_allowed_tokens_fn, num_beams=new_config.num_beams + ) + self.assertEqual(prefix_constrained_logits_proc._num_beams, num_beams) + + def test_serialize_generation_diversity_penalty_and_num_bean_groups(self): + """Tests that GenerationConfig is serialized and HammingDiversityLogitsProcessor is initialized with diversity_penalty_and_num_bean_groups""" + num_beams = 2 + num_beam_groups = 2 + diversity_penalty = 1.0 + + generation_config = GenerationConfig( + num_beams=num_beams, diversity_penalty=diversity_penalty, num_beam_groups=num_beam_groups + ) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.num_beams, num_beams) + self.assertEqual(new_config.diversity_penalty, diversity_penalty) + self.assertEqual(new_config.num_beam_groups, num_beam_groups) + + diversity_logits_processor = HammingDiversityLogitsProcessor( + diversity_penalty=new_config.diversity_penalty, + num_beams=new_config.num_beams, + num_beam_groups=new_config.num_beam_groups, + ) + self.assertEqual(diversity_logits_processor._num_beams, num_beams) + self.assertEqual(diversity_logits_processor._diversity_penalty, diversity_penalty) + self.assertEqual(diversity_logits_processor._num_sub_beams, num_beams // num_beam_groups) + + def test_serialize_generation_bos_token_id(self): + """Tests that GenerationConfig is serialized and ForcedBOSTokenLogitsProcessor is initialized with bos_token_id""" + bos_token_id = 0 + + generation_config = GenerationConfig(bos_token_id=bos_token_id) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.bos_token_id, bos_token_id) + + logits_processor = ForcedBOSTokenLogitsProcessor(bos_token_id=new_config.bos_token_id) + self.assertEqual(logits_processor.bos_token_id, bos_token_id) + + def test_serialize_generation_eos_token_id(self): + """Tests that GenerationConfig is serialized and ForcedEOSTokenLogitsProcessor is initialized with eos_token_id""" + eos_token_id = 0 + max_length = 5 + + generation_config = GenerationConfig(eos_token_id=eos_token_id) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.eos_token_id, eos_token_id) + + logits_processor = ForcedEOSTokenLogitsProcessor( + max_length=max_length, eos_token_id=new_config.eos_token_id, device=torch_device + ) + self.assertEqual(logits_processor.eos_token_id, eos_token_id) + + def test_serialize_generation_exponential_decay_length_penalty(self): + """Tests that GenerationConfig is serialized and ExponentialDecayLengthPenalty is initialized with regulation_start and regulation_factor""" + eos_token_id = 0 + penalty_start = 5 + penalty_factor = 1.1 + input_ids_seq_length = 10 + exponential_decay_length_penalty = (penalty_start, penalty_factor) + + generation_config = GenerationConfig(exponential_decay_length_penalty=exponential_decay_length_penalty) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.exponential_decay_length_penalty, [penalty_start, penalty_factor]) + + exponential_decay_processor = ExponentialDecayLengthPenalty( + exponential_decay_length_penalty=new_config.exponential_decay_length_penalty, + eos_token_id=eos_token_id, + input_ids_seq_length=input_ids_seq_length, + ) + self.assertEqual( + exponential_decay_processor.regulation_start, exponential_decay_length_penalty[0] + input_ids_seq_length + ) + self.assertEqual(exponential_decay_processor.regulation_factor, exponential_decay_length_penalty[1]) + + def test_serialize_generation_begin_suppress_tokens(self): + """Tests that GenerationConfig is serialized and SuppressTokensAtBeginLogitsProcessor is initialized with begin_suppress_token and begin_index""" + + begin_suppress_tokens = [220, 50256] + begin_index = 0 + generation_config = GenerationConfig(begin_suppress_tokens=begin_suppress_tokens) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertSequenceEqual(new_config.begin_suppress_tokens, begin_suppress_tokens) + + suppress_processor = SuppressTokensAtBeginLogitsProcessor( + begin_suppress_tokens=new_config.begin_suppress_tokens, begin_index=begin_index + ) + self.assertSequenceEqual(suppress_processor.begin_suppress_tokens, begin_suppress_tokens) + self.assertEqual(suppress_processor.begin_index, begin_index) + + def test_serialize_generation_suppress_tokens(self): + """Tests that GenerationConfig is serialized and SuppressTokensLogitsProcessor is initialized with suppress_token""" + suppress_tokens = [220, 50256] + + generation_config = GenerationConfig(suppress_tokens=suppress_tokens) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertSequenceEqual(new_config.suppress_tokens, suppress_tokens) + + suppress_processor = SuppressTokensLogitsProcessor(suppress_tokens=new_config.suppress_tokens) + self.assertSequenceEqual(suppress_processor.suppress_tokens, suppress_tokens) + + def test_serialize_generation_guidance_scale(self): + """Tests that GenerationConfig is serialized and ClassifierFreeGuidanceLogitsProcessor is initialized with guidance_scale""" + guidance_scale = 2.0 + generation_config = GenerationConfig(guidance_scale=guidance_scale) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.guidance_scale, guidance_scale) + + classifier_processor = ClassifierFreeGuidanceLogitsProcessor(guidance_scale=new_config.guidance_scale) + self.assertEqual(classifier_processor.guidance_scale, guidance_scale) + + def test_serialize_generation_guidance_scale_unbatched(self): + """Tests that GenerationConfig is serialized and UnbatchedClassifierFreeGuidanceLogitsProcessor is initialized with guidance_scale""" + guidance_scale = 2.0 + + input_ids = torch.LongTensor([[0]]) + + generation_config = GenerationConfig(guidance_scale=guidance_scale) + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.guidance_scale, guidance_scale) + + cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(new_config.guidance_scale, {}, input_ids) + self.assertEqual(cfg.guidance_scale, guidance_scale) + + def test_serialize_generation_watermarking_config(self): + """Tests that GenerationConfig is serialized and WatermarkLogitsProcessor is initialized with WatermarkingConfig parameters""" + + vocab_size = 20 + bias = 2.0 + greenlist_ratio = 0.5 + hashing_key = 10 + seeding_scheme = "lefthash" + context_width = 10 + watermarking_config = WatermarkingConfig( + bias=bias, + greenlist_ratio=greenlist_ratio, + hashing_key=hashing_key, + seeding_scheme=seeding_scheme, + context_width=context_width, + ) + generation_config = GenerationConfig(watermarking_config=watermarking_config) + + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + new_config = GenerationConfig.from_pretrained(tmp_dir) + self.assertEqual(new_config.watermarking_config.bias, bias) + self.assertEqual(new_config.watermarking_config.greenlist_ratio, greenlist_ratio) + self.assertEqual(new_config.watermarking_config.hashing_key, hashing_key) + self.assertEqual(new_config.watermarking_config.seeding_scheme, seeding_scheme) + self.assertEqual(new_config.watermarking_config.context_width, context_width) + + watermark = WatermarkLogitsProcessor( + vocab_size=vocab_size, + device=torch_device, + greenlist_ratio=new_config.watermarking_config.greenlist_ratio, + bias=new_config.watermarking_config.bias, + hashing_key=new_config.watermarking_config.hashing_key, + seeding_scheme=new_config.watermarking_config.seeding_scheme, + context_width=new_config.watermarking_config.context_width, + ) + self.assertEqual(watermark.bias, bias) + self.assertEqual(watermark.greenlist_size, int(vocab_size * greenlist_ratio)) + self.assertEqual(watermark.hash_key, hashing_key) + self.assertEqual(watermark.seeding_scheme, seeding_scheme) + self.assertEqual(watermark.context_width, context_width) + + @is_staging_test class ConfigPushToHubTester(unittest.TestCase): @classmethod