From 13e03e619da0eded113be1cd1f4be9d4494033a0 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 14 Feb 2023 10:46:46 +0000 Subject: [PATCH] Generate: filter encoder inputs when its signature does not accept wildcards (#21603) --- src/transformers/generation/tf_utils.py | 12 +++-- src/transformers/generation/utils.py | 8 +++- tests/generation/test_tf_utils.py | 38 +++++++++++++++ tests/generation/test_utils.py | 64 +++++++++++++++---------- 4 files changed, 94 insertions(+), 28 deletions(-) diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index 1b18ed1c83..fe2d5ce541 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -1078,18 +1078,24 @@ class TFGenerationMixin: def _prepare_encoder_decoder_kwargs_for_generation( self, inputs_tensor: tf.Tensor, model_kwargs, model_input_name: Optional[str] = None ) -> Dict[str, Any]: - # get encoder and store encoder outputs + # 1. get encoder and store encoder outputs encoder = self.get_encoder() - # prepare encoder args and encoder kwargs from model kwargs + # 2. prepare encoder args and encoder kwargs from model kwargs irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] encoder_kwargs = { argument: value for argument, value in model_kwargs.items() if not any(argument.startswith(p) for p in irrelevant_prefix) } + encoder_signature = set(inspect.signature(encoder.call).parameters) + encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature + if not encoder_accepts_wildcard: + encoder_kwargs = { + argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature + } - # vision models don't use `attention_mask`. + # 3. vision models don't use `attention_mask`. encoder_kwargs["return_dict"] = True encoder_kwargs[model_input_name] = inputs_tensor if model_input_name != self.main_input_name: # in Keras, the first input must always be passed diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 014c66166f..ce955baef1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -609,13 +609,19 @@ class GenerationMixin: # 1. get encoder encoder = self.get_encoder() - # 2. prepare encoder args and encoder kwargs from model kwargs + # 2. Prepare encoder args and encoder kwargs from model kwargs. irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] encoder_kwargs = { argument: value for argument, value in model_kwargs.items() if not any(argument.startswith(p) for p in irrelevant_prefix) } + encoder_signature = set(inspect.signature(encoder.forward).parameters) + encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature + if not encoder_accepts_wildcard: + encoder_kwargs = { + argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature + } # 3. make sure that encoder returns `ModelOutput` model_input_name = model_input_name if model_input_name is not None else self.main_input_name diff --git a/tests/generation/test_tf_utils.py b/tests/generation/test_tf_utils.py index dc3f14c7f2..cab4512bec 100644 --- a/tests/generation/test_tf_utils.py +++ b/tests/generation/test_tf_utils.py @@ -16,6 +16,8 @@ import tempfile import unittest +import numpy as np + from transformers import is_tf_available from transformers.testing_utils import require_tf, slow @@ -32,6 +34,7 @@ if is_tf_available(): TFAutoModelForSeq2SeqLM, TFAutoModelForSpeechSeq2Seq, TFAutoModelForVision2Seq, + TFBartForConditionalGeneration, TFLogitsProcessorList, TFMinLengthLogitsProcessor, tf_top_k_top_p_filtering, @@ -264,3 +267,38 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests tf.random.set_seed(0) generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) self.assertTrue(expectation == len(generated_tokens[0])) + + def test_model_kwarg_encoder_signature_filtering(self): + # Has PT equivalent: ample use of framework-specific code + bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + article = """Hugging Face is a technology company based in New York and Paris.""" + input_ids = bart_tokenizer(article, return_tensors="tf").input_ids + bart_model = TFBartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart") + output = bart_model.generate(input_ids).numpy() + + # Let's create a fake model that has a different signature. In particular, this fake model accepts "foo" as an + # argument. Because "foo" is not in the encoder signature and doesn't start with "decoder_", it will be part of + # the encoder kwargs prior to signature filtering, which would lead to an exception. But filtering kicks in and + # saves the day. + class FakeBart(TFBartForConditionalGeneration): + def call(self, input_ids, foo=None, **kwargs): + return super().call(input_ids, **kwargs) + + bart_model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-bart") + fake_output = bart_model.generate(input_ids, foo="bar").numpy() + self.assertTrue(np.array_equal(output, fake_output)) + + # Encoder signature filtering only kicks in if it doesn't accept wildcard kwargs. The following test will fail + # because it doesn't do signature filtering. + class FakeEncoder(bart_model.model.encoder.__class__): + def call(self, input_ids, **kwargs): + return super().call(input_ids, **kwargs) + + fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared) + bart_model.model.encoder = fake_encoder + + # Normal generation still works (the output will be different because the encoder weights are different) + fake_output = bart_model.generate(input_ids).numpy() + with self.assertRaises(ValueError): + # FakeEncoder.call() accepts **kwargs -> no filtering -> value error due to unexpected input "foo" + bart_model.generate(input_ids, foo="bar") diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 88559e58b6..1287e4a876 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -17,6 +17,8 @@ import inspect import unittest +import numpy as np + from transformers import is_torch_available, pipeline from transformers.testing_utils import require_torch, slow, torch_device @@ -2439,30 +2441,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max() self.assertTrue(max_score_diff < 1e-5) - def test_generate_from_input_embeds_decoder_only(self): - # PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;)) - # Note: the model must support generation from input embeddings - model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2") - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - - text = "Hello world" - input_ids = tokenizer.encode(text, return_tensors="pt") - - # Traditional way of generating text - outputs_from_ids = model.generate(input_ids) - - # Same thing, but from input embeddings - inputs_embeds = model.transformer.wte(input_ids) - outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds) - self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist()) - - # But if we pass different inputs_embeds, we should get different outputs - torch.manual_seed(0) - random_embeds = torch.rand_like(inputs_embeds) - outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds) - with self.assertRaises(AssertionError): - self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist()) - def test_eos_token_id_int_and_list_top_k_top_sampling(self): # Has TF equivalent: this test relies on random sampling generation_kwargs = { @@ -2490,6 +2468,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi self.assertTrue(expectation == len(generated_tokens[0])) def test_generate_from_inputs_embeds_decoder_only(self): + # PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;)) # Note: the model must support generation from input embeddings model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") @@ -2523,3 +2502,40 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(), outputs_from_embeds_wo_ids[:, 1:].tolist(), ) + + def test_model_kwarg_encoder_signature_filtering(self): + # Has TF equivalent: ample use of framework-specific code + bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + article = """Hugging Face is a technology company based in New York and Paris.""" + input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) + output = bart_model.generate(input_ids).cpu().numpy() + + # Let's create a fake model that has a different signature. In particular, this fake model accepts "foo" as an + # argument. Because "foo" is not in the encoder signature and doesn't start with "decoder_", it will be part of + # the encoder kwargs prior to signature filtering, which would lead to an exception. But filtering kicks in and + # saves the day. + class FakeBart(BartForConditionalGeneration): + def forward(self, input_ids, foo=None, **kwargs): + return super().forward(input_ids, **kwargs) + + bart_model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device) + fake_output = bart_model.generate(input_ids, foo="bar").cpu().numpy() + self.assertTrue(np.array_equal(output, fake_output)) + + # Encoder signature filtering only kicks in if it doesn't accept wildcard kwargs. The following test will fail + # because it doesn't do signature filtering. + class FakeEncoder(bart_model.model.encoder.__class__): + def forward(self, input_ids, **kwargs): + return super().forward(input_ids, **kwargs) + + fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared).to(torch_device) + bart_model.model.encoder = fake_encoder + + # Normal generation still works (the output will be different because the encoder weights are different) + fake_output = bart_model.generate(input_ids).cpu().numpy() + with self.assertRaises(TypeError): + # FakeEncoder.forward() accepts **kwargs -> no filtering -> type error due to unexpected input "foo" + bart_model.generate(input_ids, foo="bar")