Generate: filter encoder inputs when its signature does not accept wildcards (#21603)
This commit is contained in:
@@ -1078,18 +1078,24 @@ class TFGenerationMixin:
|
|||||||
def _prepare_encoder_decoder_kwargs_for_generation(
|
def _prepare_encoder_decoder_kwargs_for_generation(
|
||||||
self, inputs_tensor: tf.Tensor, model_kwargs, model_input_name: Optional[str] = None
|
self, inputs_tensor: tf.Tensor, model_kwargs, model_input_name: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
# get encoder and store encoder outputs
|
# 1. get encoder and store encoder outputs
|
||||||
encoder = self.get_encoder()
|
encoder = self.get_encoder()
|
||||||
|
|
||||||
# prepare encoder args and encoder kwargs from model kwargs
|
# 2. prepare encoder args and encoder kwargs from model kwargs
|
||||||
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
|
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
|
||||||
encoder_kwargs = {
|
encoder_kwargs = {
|
||||||
argument: value
|
argument: value
|
||||||
for argument, value in model_kwargs.items()
|
for argument, value in model_kwargs.items()
|
||||||
if not any(argument.startswith(p) for p in irrelevant_prefix)
|
if not any(argument.startswith(p) for p in irrelevant_prefix)
|
||||||
}
|
}
|
||||||
|
encoder_signature = set(inspect.signature(encoder.call).parameters)
|
||||||
|
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
|
||||||
|
if not encoder_accepts_wildcard:
|
||||||
|
encoder_kwargs = {
|
||||||
|
argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
|
||||||
|
}
|
||||||
|
|
||||||
# vision models don't use `attention_mask`.
|
# 3. vision models don't use `attention_mask`.
|
||||||
encoder_kwargs["return_dict"] = True
|
encoder_kwargs["return_dict"] = True
|
||||||
encoder_kwargs[model_input_name] = inputs_tensor
|
encoder_kwargs[model_input_name] = inputs_tensor
|
||||||
if model_input_name != self.main_input_name: # in Keras, the first input must always be passed
|
if model_input_name != self.main_input_name: # in Keras, the first input must always be passed
|
||||||
|
|||||||
@@ -609,13 +609,19 @@ class GenerationMixin:
|
|||||||
# 1. get encoder
|
# 1. get encoder
|
||||||
encoder = self.get_encoder()
|
encoder = self.get_encoder()
|
||||||
|
|
||||||
# 2. prepare encoder args and encoder kwargs from model kwargs
|
# 2. Prepare encoder args and encoder kwargs from model kwargs.
|
||||||
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
|
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
|
||||||
encoder_kwargs = {
|
encoder_kwargs = {
|
||||||
argument: value
|
argument: value
|
||||||
for argument, value in model_kwargs.items()
|
for argument, value in model_kwargs.items()
|
||||||
if not any(argument.startswith(p) for p in irrelevant_prefix)
|
if not any(argument.startswith(p) for p in irrelevant_prefix)
|
||||||
}
|
}
|
||||||
|
encoder_signature = set(inspect.signature(encoder.forward).parameters)
|
||||||
|
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
|
||||||
|
if not encoder_accepts_wildcard:
|
||||||
|
encoder_kwargs = {
|
||||||
|
argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
|
||||||
|
}
|
||||||
|
|
||||||
# 3. make sure that encoder returns `ModelOutput`
|
# 3. make sure that encoder returns `ModelOutput`
|
||||||
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
|
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
|
||||||
|
|||||||
@@ -16,6 +16,8 @@
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from transformers import is_tf_available
|
from transformers import is_tf_available
|
||||||
from transformers.testing_utils import require_tf, slow
|
from transformers.testing_utils import require_tf, slow
|
||||||
|
|
||||||
@@ -32,6 +34,7 @@ if is_tf_available():
|
|||||||
TFAutoModelForSeq2SeqLM,
|
TFAutoModelForSeq2SeqLM,
|
||||||
TFAutoModelForSpeechSeq2Seq,
|
TFAutoModelForSpeechSeq2Seq,
|
||||||
TFAutoModelForVision2Seq,
|
TFAutoModelForVision2Seq,
|
||||||
|
TFBartForConditionalGeneration,
|
||||||
TFLogitsProcessorList,
|
TFLogitsProcessorList,
|
||||||
TFMinLengthLogitsProcessor,
|
TFMinLengthLogitsProcessor,
|
||||||
tf_top_k_top_p_filtering,
|
tf_top_k_top_p_filtering,
|
||||||
@@ -264,3 +267,38 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
|||||||
tf.random.set_seed(0)
|
tf.random.set_seed(0)
|
||||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|
||||||
|
def test_model_kwarg_encoder_signature_filtering(self):
|
||||||
|
# Has PT equivalent: ample use of framework-specific code
|
||||||
|
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
article = """Hugging Face is a technology company based in New York and Paris."""
|
||||||
|
input_ids = bart_tokenizer(article, return_tensors="tf").input_ids
|
||||||
|
bart_model = TFBartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
output = bart_model.generate(input_ids).numpy()
|
||||||
|
|
||||||
|
# Let's create a fake model that has a different signature. In particular, this fake model accepts "foo" as an
|
||||||
|
# argument. Because "foo" is not in the encoder signature and doesn't start with "decoder_", it will be part of
|
||||||
|
# the encoder kwargs prior to signature filtering, which would lead to an exception. But filtering kicks in and
|
||||||
|
# saves the day.
|
||||||
|
class FakeBart(TFBartForConditionalGeneration):
|
||||||
|
def call(self, input_ids, foo=None, **kwargs):
|
||||||
|
return super().call(input_ids, **kwargs)
|
||||||
|
|
||||||
|
bart_model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
fake_output = bart_model.generate(input_ids, foo="bar").numpy()
|
||||||
|
self.assertTrue(np.array_equal(output, fake_output))
|
||||||
|
|
||||||
|
# Encoder signature filtering only kicks in if the encoder doesn't accept wildcard kwargs. The `generate` call with
|
||||||
|
# `foo` below is expected to raise, because an encoder that accepts wildcards gets no signature filtering.
|
||||||
|
class FakeEncoder(bart_model.model.encoder.__class__):
|
||||||
|
def call(self, input_ids, **kwargs):
|
||||||
|
return super().call(input_ids, **kwargs)
|
||||||
|
|
||||||
|
fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared)
|
||||||
|
bart_model.model.encoder = fake_encoder
|
||||||
|
|
||||||
|
# Normal generation still works (the output will be different because the encoder weights are different)
|
||||||
|
fake_output = bart_model.generate(input_ids).numpy()
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
# FakeEncoder.call() accepts **kwargs -> no filtering -> value error due to unexpected input "foo"
|
||||||
|
bart_model.generate(input_ids, foo="bar")
|
||||||
|
|||||||
@@ -17,6 +17,8 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from transformers import is_torch_available, pipeline
|
from transformers import is_torch_available, pipeline
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
@@ -2439,30 +2441,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
|
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
|
||||||
self.assertTrue(max_score_diff < 1e-5)
|
self.assertTrue(max_score_diff < 1e-5)
|
||||||
|
|
||||||
def test_generate_from_input_embeds_decoder_only(self):
|
|
||||||
# PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;))
|
|
||||||
# Note: the model must support generation from input embeddings
|
|
||||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
|
||||||
|
|
||||||
text = "Hello world"
|
|
||||||
input_ids = tokenizer.encode(text, return_tensors="pt")
|
|
||||||
|
|
||||||
# Traditional way of generating text
|
|
||||||
outputs_from_ids = model.generate(input_ids)
|
|
||||||
|
|
||||||
# Same thing, but from input embeddings
|
|
||||||
inputs_embeds = model.transformer.wte(input_ids)
|
|
||||||
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
|
|
||||||
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
|
|
||||||
|
|
||||||
# But if we pass different inputs_embeds, we should get different outputs
|
|
||||||
torch.manual_seed(0)
|
|
||||||
random_embeds = torch.rand_like(inputs_embeds)
|
|
||||||
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
|
|
||||||
|
|
||||||
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
|
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
|
||||||
# Has TF equivalent: this test relies on random sampling
|
# Has TF equivalent: this test relies on random sampling
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
@@ -2490,6 +2468,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|
||||||
def test_generate_from_inputs_embeds_decoder_only(self):
|
def test_generate_from_inputs_embeds_decoder_only(self):
|
||||||
|
# PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;))
|
||||||
# Note: the model must support generation from input embeddings
|
# Note: the model must support generation from input embeddings
|
||||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
@@ -2523,3 +2502,40 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
|
outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
|
||||||
outputs_from_embeds_wo_ids[:, 1:].tolist(),
|
outputs_from_embeds_wo_ids[:, 1:].tolist(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_model_kwarg_encoder_signature_filtering(self):
|
||||||
|
# Has TF equivalent: ample use of framework-specific code
|
||||||
|
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
article = """Hugging Face is a technology company based in New York and Paris."""
|
||||||
|
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||||
|
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
output = bart_model.generate(input_ids).cpu().numpy()
|
||||||
|
|
||||||
|
# Let's create a fake model that has a different signature. In particular, this fake model accepts "foo" as an
|
||||||
|
# argument. Because "foo" is not in the encoder signature and doesn't start with "decoder_", it will be part of
|
||||||
|
# the encoder kwargs prior to signature filtering, which would lead to an exception. But filtering kicks in and
|
||||||
|
# saves the day.
|
||||||
|
class FakeBart(BartForConditionalGeneration):
|
||||||
|
def forward(self, input_ids, foo=None, **kwargs):
|
||||||
|
return super().forward(input_ids, **kwargs)
|
||||||
|
|
||||||
|
bart_model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device)
|
||||||
|
fake_output = bart_model.generate(input_ids, foo="bar").cpu().numpy()
|
||||||
|
self.assertTrue(np.array_equal(output, fake_output))
|
||||||
|
|
||||||
|
# Encoder signature filtering only kicks in if the encoder doesn't accept wildcard kwargs. The `generate` call with
|
||||||
|
# `foo` below is expected to raise, because an encoder that accepts wildcards gets no signature filtering.
|
||||||
|
class FakeEncoder(bart_model.model.encoder.__class__):
|
||||||
|
def forward(self, input_ids, **kwargs):
|
||||||
|
return super().forward(input_ids, **kwargs)
|
||||||
|
|
||||||
|
fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared).to(torch_device)
|
||||||
|
bart_model.model.encoder = fake_encoder
|
||||||
|
|
||||||
|
# Normal generation still works (the output will be different because the encoder weights are different)
|
||||||
|
fake_output = bart_model.generate(input_ids).cpu().numpy()
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
# FakeEncoder.forward() accepts **kwargs -> no filtering -> type error due to unexpected input "foo"
|
||||||
|
bart_model.generate(input_ids, foo="bar")
|
||||||
|
|||||||
Reference in New Issue
Block a user