Generate: filter encoder inputs when its signature does not accept wildcards (#21603)
This commit is contained in:
@@ -16,6 +16,8 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
|
||||
@@ -32,6 +34,7 @@ if is_tf_available():
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFAutoModelForSpeechSeq2Seq,
|
||||
TFAutoModelForVision2Seq,
|
||||
TFBartForConditionalGeneration,
|
||||
TFLogitsProcessorList,
|
||||
TFMinLengthLogitsProcessor,
|
||||
tf_top_k_top_p_filtering,
|
||||
@@ -264,3 +267,38 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
||||
tf.random.set_seed(0)
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
def test_model_kwarg_encoder_signature_filtering(self):
    """Exercise encoder-kwarg signature filtering in `generate()`.

    Three scenarios are covered, in order:

    1. A stock BART model produces the reference output.
    2. A model subclass whose `call` accepts an extra `foo` argument must
       produce the same output when `foo` is passed to `generate()`: the
       encoder signature has no wildcard, so `foo` is filtered out before
       the encoder is invoked.
    3. Once the encoder is swapped for one whose `call` accepts `**kwargs`,
       no filtering happens and passing `foo` is expected to raise a
       `ValueError`.
    """
    # Has PT equivalent: ample use of framework-specific code
    checkpoint = "hf-internal-testing/tiny-random-bart"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    article = """Hugging Face is a technology company based in New York and Paris."""
    input_ids = tokenizer(article, return_tensors="tf").input_ids

    # Scenario 1: reference output from the unmodified model.
    reference_model = TFBartForConditionalGeneration.from_pretrained(checkpoint)
    expected_tokens = reference_model.generate(input_ids).numpy()

    # Scenario 2: a fake model with a different signature. It accepts "foo" as an argument. Since "foo" is not
    # in the encoder signature and does not start with "decoder_", it lands in the encoder kwargs prior to
    # signature filtering, which would raise an exception -- unless filtering removes it first.
    class FakeBart(TFBartForConditionalGeneration):
        def call(self, input_ids, foo=None, **kwargs):
            return super().call(input_ids, **kwargs)

    fake_model = FakeBart.from_pretrained(checkpoint)
    tokens_with_foo = fake_model.generate(input_ids, foo="bar").numpy()
    outputs_match = np.array_equal(expected_tokens, tokens_with_foo)
    self.assertTrue(outputs_match)

    # Scenario 3: signature filtering only kicks in when the encoder does NOT accept wildcard kwargs. The
    # encoder below does accept them, so nothing is filtered and the unexpected "foo" reaches the encoder.
    encoder_cls = fake_model.model.encoder.__class__

    class FakeEncoder(encoder_cls):
        def call(self, input_ids, **kwargs):
            return super().call(input_ids, **kwargs)

    fake_model.model.encoder = FakeEncoder(fake_model.config, fake_model.model.shared)

    # Normal generation still works (the output will be different because the encoder weights are different),
    # so the result is intentionally discarded: this is only a smoke check.
    fake_model.generate(input_ids).numpy()

    # FakeEncoder.call() accepts **kwargs -> no filtering -> value error due to unexpected input "foo"
    with self.assertRaises(ValueError):
        fake_model.generate(input_ids, foo="bar")
|
||||
|
||||
Reference in New Issue
Block a user