Template for framework-agnostic tests (#21348)

This commit is contained in:
Joao Gante
2023-01-31 11:33:18 +00:00
committed by GitHub
parent 5451f8896c
commit 623346ab18
4 changed files with 68 additions and 41 deletions

View File

@@ -19,11 +19,13 @@ import unittest
from transformers import is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_framework_agnostic import GenerationIntegrationTestsMixin
if is_tf_available():
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering
from transformers import TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering
@require_tf
@@ -124,7 +126,16 @@ class UtilsFunctionsTest(unittest.TestCase):
@require_tf
class TFGenerationIntegrationTests(unittest.TestCase):
class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
if is_tf_available():
framework_dependent_parameters = {
"AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
"create_tensor_fn": tf.convert_to_tensor,
"return_tensors": "tf",
}
@slow
def test_generate_tf_function_export(self):
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
@@ -165,19 +176,3 @@ class TFGenerationIntegrationTests(unittest.TestCase):
tf_func_outputs = serving_func(**inputs)["sequences"]
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
encoder_input_str = "Hello world"
input_ids = tokenizer(encoder_input_str, return_tensors="tf").input_ids
# typos are quickly detected (the correct argument is `do_sample`)
with self.assertRaisesRegex(ValueError, "do_samples"):
model.generate(input_ids, do_samples=True)
# arbitrary arguments that will not be used anywhere are also not accepted
with self.assertRaisesRegex(ValueError, "foo"):
fake_model_kwargs = {"foo": "bar"}
model.generate(input_ids, **fake_model_kwargs)