From 4943331015329d40b2cf60721c580aa3617d45e3 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 6 Feb 2023 15:44:47 +0000 Subject: [PATCH] Generate: TF can now accept custom logits processors (#21454) --- src/transformers/generation/tf_utils.py | 32 +++++++++++++++++++ .../models/rag/modeling_tf_rag.py | 7 ++++ tests/generation/test_framework_agnostic.py | 22 +++++++++++++ tests/generation/test_tf_utils.py | 10 +++++- tests/generation/test_utils.py | 29 +++++++---------- 5 files changed, 81 insertions(+), 19 deletions(-) diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index c06e6132ec..ec246aaff0 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -532,6 +532,7 @@ class TFGenerationMixin: self, input_ids: Optional[tf.Tensor] = None, generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[TFLogitsProcessorList] = None, seed=None, **kwargs, ) -> Union[TFGenerateOutput, tf.Tensor]: @@ -560,6 +561,10 @@ class TFGenerationMixin: priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s default values, whose documentation should be checked to parameterize generation. + logits_processor (`TFLogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logits processor is passed that is already created with the arguments or a + generation config, an error is thrown. This feature is intended for advanced users. seed (`List[int]`, *optional*): Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the `seed` argument from stateless functions in `tf.random`. 
@@ -638,6 +643,8 @@ class TFGenerationMixin: model_kwargs["decoder_input_ids"] = tf.cast(model_kwargs["decoder_input_ids"], tf.int32) # 3. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() + if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: if model_kwargs.get("attention_mask") is None: logger.warning( @@ -755,6 +762,7 @@ class TFGenerationMixin: logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, + logits_processor=logits_processor, ) # 10. go into different generation modes @@ -1194,6 +1202,7 @@ class TFGenerationMixin: self, generation_config: GenerationConfig, input_ids_seq_length: int, + logits_processor: Optional[TFLogitsProcessorList], ) -> TFLogitsProcessorList: """ This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`] @@ -1240,8 +1249,31 @@ class TFGenerationMixin: ) if generation_config.forced_decoder_ids is not None: processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids)) + + processors = self._merge_criteria_processor_list(processors, logits_processor) return processors + def _merge_criteria_processor_list( + self, + default_list: TFLogitsProcessorList, + custom_list: TFLogitsProcessorList, + ) -> TFLogitsProcessorList: + if len(custom_list) == 0: + return default_list + for default in default_list: + for custom in custom_list: + if type(custom) is type(default): + object_type = "logits processor" + raise ValueError( + f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" + f" `generate`, but it has already been created with the values {default}. {default} has been" + " created by passing the corresponding arguments to generate or by the model's config default" + f" values. 
If you just want to change the default values of {object_type} consider passing" + f" them as arguments to `generate` instead of using a custom {object_type}." + ) + default_list.extend(custom_list) + return default_list + def greedy_search( self, input_ids: tf.Tensor, diff --git a/src/transformers/models/rag/modeling_tf_rag.py b/src/transformers/models/rag/modeling_tf_rag.py index 81c9d94c1a..cda15a8b44 100644 --- a/src/transformers/models/rag/modeling_tf_rag.py +++ b/src/transformers/models/rag/modeling_tf_rag.py @@ -23,6 +23,7 @@ import numpy as np import tensorflow as tf from ...configuration_utils import PretrainedConfig +from ...generation import TFLogitsProcessorList from ...modeling_tf_utils import ( TFCausalLanguageModelingLoss, TFModelInputType, @@ -1002,6 +1003,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss doc_scores=None, n_docs=None, generation_config=None, + logits_processor=TFLogitsProcessorList(), **kwargs ): """ @@ -1045,6 +1047,10 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s default values, whose documentation should be checked to parameterize generation. + logits_processor (`TFLogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and a + model's config. If a logit processor is passed that is already created with the arguments or a model's + config an error is thrown. kwargs: Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. 
@@ -1149,6 +1155,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss pre_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=tf.shape(decoder_input_ids)[-1], + logits_processor=logits_processor, ) if generation_config.num_beams == 1: diff --git a/tests/generation/test_framework_agnostic.py b/tests/generation/test_framework_agnostic.py index 31cc78d411..014ed4af1e 100644 --- a/tests/generation/test_framework_agnostic.py +++ b/tests/generation/test_framework_agnostic.py @@ -12,6 +12,8 @@ class GenerationIntegrationTestsMixin: # To be populated by the child classes framework_dependent_parameters = { "AutoModelForSeq2SeqLM": None, + "LogitsProcessorList": None, + "MinLengthLogitsProcessor": None, "create_tensor_fn": None, "return_tensors": None, } @@ -39,3 +41,23 @@ class GenerationIntegrationTestsMixin: # however, valid model_kwargs are accepted valid_model_kwargs = {"attention_mask": create_tensor_fn(np.zeros_like(input_ids))} model.generate(input_ids, **valid_model_kwargs) + + def test_custom_logits_processor(self): + model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] + logits_processor_list_cls = self.framework_dependent_parameters["LogitsProcessorList"] + min_length_logits_processor_cls = self.framework_dependent_parameters["MinLengthLogitsProcessor"] + return_tensors = self.framework_dependent_parameters["return_tensors"] + + bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + bart_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", min_length=1) + input_ids = bart_tokenizer(article, return_tensors=return_tensors).input_ids + + logits_processor = logits_processor_list_cls() + logits_processor.append(min_length_logits_processor_cls(min_length=10, eos_token_id=0)) + # it should not be allowed to both define `min_length` via config 
and `logits_processor` list + with self.assertRaises(ValueError): + bart_model.generate(input_ids, logits_processor=logits_processor) + + bart_model.config.min_length = None + bart_model.generate(input_ids, logits_processor=logits_processor) diff --git a/tests/generation/test_tf_utils.py b/tests/generation/test_tf_utils.py index 42eac59e50..7d50057189 100644 --- a/tests/generation/test_tf_utils.py +++ b/tests/generation/test_tf_utils.py @@ -25,7 +25,13 @@ from .test_framework_agnostic import GenerationIntegrationTestsMixin if is_tf_available(): import tensorflow as tf - from transformers import TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering + from transformers import ( + TFAutoModelForCausalLM, + TFAutoModelForSeq2SeqLM, + TFLogitsProcessorList, + TFMinLengthLogitsProcessor, + tf_top_k_top_p_filtering, + ) @require_tf @@ -132,6 +138,8 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests if is_tf_available(): framework_dependent_parameters = { "AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM, + "LogitsProcessorList": TFLogitsProcessorList, + "MinLengthLogitsProcessor": TFMinLengthLogitsProcessor, "create_tensor_fn": tf.convert_to_tensor, "return_tensors": "tf", } diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index cb1c2460db..5bbefef8f1 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1797,12 +1797,15 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi if is_torch_available(): framework_dependent_parameters = { "AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM, + "LogitsProcessorList": LogitsProcessorList, + "MinLengthLogitsProcessor": MinLengthLogitsProcessor, "create_tensor_fn": torch.tensor, "return_tensors": "pt", } @slow def test_diverse_beam_search(self): + # PT-only test: TF doesn't have a diverse beam search implementation article = """Justin Timberlake and Jessica Biel, welcome to parenthood. 
The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People. "Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports. @@ -1836,6 +1839,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ) def test_max_length_backward_compat_greedy(self): + # PT-only test: TF doesn't have StoppingCriteria article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( @@ -1862,6 +1866,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ) def test_max_length_backward_compat_sample(self): + # PT-only test: TF doesn't have StoppingCriteria article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( @@ -1888,6 +1893,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ) def test_max_length_backward_compat_beam_search(self): + # PT-only test: TF doesn't have StoppingCriteria article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( @@ -1918,6 +1924,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ) def test_max_length_backward_compat_group_beam_search(self): + # PT-only test: TF doesn't have StoppingCriteria & group beam search article = """Justin Timberlake and Jessica Biel, 
welcome to parenthood.""" bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( @@ -1952,6 +1959,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ) def test_max_length_warning_if_different(self): + # PT-only test: TF doesn't have StoppingCriteria article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( @@ -2035,6 +2043,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ) def test_custom_stopping_criteria_overload_error(self): + # PT-only test: TF doesn't have StoppingCriteria article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) @@ -2048,6 +2057,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32) def test_custom_stopping_criteria(self): + # PT-only test: TF doesn't have StoppingCriteria article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) @@ -2070,7 +2080,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ) def test_stop_sequence_stopping_criteria(self): - + # PT-only test: TF doesn't have StoppingCriteria prompt = """Hello I believe in""" generator = pipeline("text-generation", 
model="hf-internal-testing/tiny-random-bart") output = generator(prompt) @@ -2088,23 +2098,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi output = generator(prompt, stop_sequence=" number") self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}]) - def test_custom_logits_processor(self): - bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") - article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random", min_length=1).to( - torch_device - ) - input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - - logits_processor = LogitsProcessorList() - logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0)) - # it should not be allowed to both define `min_length` via config and `logits_processor` list - with self.assertRaises(ValueError): - bart_model.generate(input_ids, logits_processor=logits_processor) - - bart_model.config.min_length = None - bart_model.generate(input_ids, logits_processor=logits_processor) - def test_max_new_tokens_encoder_decoder(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")