Template for framework-agnostic tests (#21348)
This commit is contained in:
41
tests/generation/test_framework_agnostic.py
Normal file
41
tests/generation/test_framework_agnostic.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""
|
||||||
|
Framework agnostic tests for generate()-related methods.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationIntegrationTestsMixin:
    """
    Integration checks for `generate()` that do not depend on a specific deep learning framework.

    The test bodies call `unittest.TestCase` assertion helpers (e.g. `assertRaisesRegex`), so this mixin is
    expected to be combined with `unittest.TestCase` in the concrete test class. That class must also
    override `framework_dependent_parameters` with the objects of the framework under test.
    """

    # Placeholders only -- child classes replace each value with its framework-specific counterpart
    framework_dependent_parameters = {
        "AutoModelForSeq2SeqLM": None,
        "create_tensor_fn": None,
        "return_tensors": None,
    }

    def test_validate_generation_inputs(self):
        # `generate()` must reject keyword arguments it cannot use, and accept valid model kwargs
        params = self.framework_dependent_parameters
        model_cls = params["AutoModelForSeq2SeqLM"]
        return_tensors = params["return_tensors"]
        create_tensor_fn = params["create_tensor_fn"]

        checkpoint = "hf-internal-testing/tiny-random-t5"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = model_cls.from_pretrained(checkpoint)

        encoded = tokenizer("Hello world", return_tensors=return_tensors)
        input_ids = encoded.input_ids

        # a misspelled flag (`do_samples` instead of `do_sample`) has to be reported by name
        with self.assertRaisesRegex(ValueError, "do_samples"):
            model.generate(input_ids, do_samples=True)

        # a keyword argument that nothing consumes has to be rejected as well
        unused_kwargs = {"foo": "bar"}
        with self.assertRaisesRegex(ValueError, "foo"):
            model.generate(input_ids, **unused_kwargs)

        # a legitimate model kwarg, on the other hand, goes through without raising
        attention_mask = create_tensor_fn(np.zeros_like(input_ids))
        model.generate(input_ids, attention_mask=attention_mask)
|
||||||
@@ -19,11 +19,13 @@ import unittest
|
|||||||
from transformers import is_tf_available
|
from transformers import is_tf_available
|
||||||
from transformers.testing_utils import require_tf, slow
|
from transformers.testing_utils import require_tf, slow
|
||||||
|
|
||||||
|
from .test_framework_agnostic import GenerationIntegrationTestsMixin
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering
|
from transformers import TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@@ -124,7 +126,16 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFGenerationIntegrationTests(unittest.TestCase):
|
class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
|
||||||
|
|
||||||
|
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
|
||||||
|
if is_tf_available():
|
||||||
|
framework_dependent_parameters = {
|
||||||
|
"AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
|
||||||
|
"create_tensor_fn": tf.convert_to_tensor,
|
||||||
|
"return_tensors": "tf",
|
||||||
|
}
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_generate_tf_function_export(self):
|
def test_generate_tf_function_export(self):
|
||||||
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
@@ -165,19 +176,3 @@ class TFGenerationIntegrationTests(unittest.TestCase):
|
|||||||
tf_func_outputs = serving_func(**inputs)["sequences"]
|
tf_func_outputs = serving_func(**inputs)["sequences"]
|
||||||
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
|
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
|
||||||
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
|
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
|
||||||
|
|
||||||
def test_validate_generation_inputs(self):
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
|
||||||
model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
|
||||||
|
|
||||||
encoder_input_str = "Hello world"
|
|
||||||
input_ids = tokenizer(encoder_input_str, return_tensors="tf").input_ids
|
|
||||||
|
|
||||||
# typos are quickly detected (the correct argument is `do_sample`)
|
|
||||||
with self.assertRaisesRegex(ValueError, "do_samples"):
|
|
||||||
model.generate(input_ids, do_samples=True)
|
|
||||||
|
|
||||||
# arbitrary arguments that will not be used anywhere are also not accepted
|
|
||||||
with self.assertRaisesRegex(ValueError, "foo"):
|
|
||||||
fake_model_kwargs = {"foo": "bar"}
|
|
||||||
model.generate(input_ids, **fake_model_kwargs)
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from transformers import is_torch_available, pipeline
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from ..test_modeling_common import floats_tensor, ids_tensor
|
from ..test_modeling_common import floats_tensor, ids_tensor
|
||||||
|
from .test_framework_agnostic import GenerationIntegrationTestsMixin
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -1790,7 +1791,16 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class GenerationIntegrationTests(unittest.TestCase):
|
class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
|
||||||
|
|
||||||
|
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
|
||||||
|
if is_torch_available():
|
||||||
|
framework_dependent_parameters = {
|
||||||
|
"AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
|
||||||
|
"create_tensor_fn": torch.tensor,
|
||||||
|
"return_tensors": "pt",
|
||||||
|
}
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_diverse_beam_search(self):
|
def test_diverse_beam_search(self):
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
|
||||||
@@ -3022,26 +3032,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
|
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
|
||||||
self.assertTrue(max_score_diff < 1e-5)
|
self.assertTrue(max_score_diff < 1e-5)
|
||||||
|
|
||||||
def test_validate_generation_inputs(self):
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta")
|
|
||||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta")
|
|
||||||
|
|
||||||
encoder_input_str = "Hello world"
|
|
||||||
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
|
||||||
|
|
||||||
# typos are quickly detected (the correct argument is `do_sample`)
|
|
||||||
with self.assertRaisesRegex(ValueError, "do_samples"):
|
|
||||||
model.generate(input_ids, do_samples=True)
|
|
||||||
|
|
||||||
# arbitrary arguments that will not be used anywhere are also not accepted
|
|
||||||
with self.assertRaisesRegex(ValueError, "foo"):
|
|
||||||
fake_model_kwargs = {"foo": "bar"}
|
|
||||||
model.generate(input_ids, **fake_model_kwargs)
|
|
||||||
|
|
||||||
# However, valid model_kwargs are accepted
|
|
||||||
valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
|
|
||||||
model.generate(input_ids, **valid_model_kwargs)
|
|
||||||
|
|
||||||
def test_eos_token_id_int_and_list_greedy_search(self):
|
def test_eos_token_id_int_and_list_greedy_search(self):
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
"do_sample": False,
|
"do_sample": False,
|
||||||
|
|||||||
@@ -466,12 +466,13 @@ def module_to_test_file(module_fname):
|
|||||||
# This list contains the list of test files we expect never to be launched from a change in a module/util. Those are
|
# This list contains the list of test files we expect never to be launched from a change in a module/util. Those are
|
||||||
# launched separately.
|
# launched separately.
|
||||||
EXPECTED_TEST_FILES_NEVER_TOUCHED = [
|
EXPECTED_TEST_FILES_NEVER_TOUCHED = [
|
||||||
"tests/utils/test_doc_samples.py", # Doc tests
|
"tests/generation/test_framework_agnostic.py", # Mixins inherited by actual test classes
|
||||||
|
"tests/mixed_int8/test_mixed_int8.py", # Mixed-int8 bitsandbytes test
|
||||||
"tests/pipelines/test_pipelines_common.py", # Actually checked by the pipeline based file
|
"tests/pipelines/test_pipelines_common.py", # Actually checked by the pipeline based file
|
||||||
"tests/sagemaker/test_single_node_gpu.py", # SageMaker test
|
"tests/sagemaker/test_single_node_gpu.py", # SageMaker test
|
||||||
"tests/sagemaker/test_multi_node_model_parallel.py", # SageMaker test
|
"tests/sagemaker/test_multi_node_model_parallel.py", # SageMaker test
|
||||||
"tests/sagemaker/test_multi_node_data_parallel.py", # SageMaker test
|
"tests/sagemaker/test_multi_node_data_parallel.py", # SageMaker test
|
||||||
"tests/mixed_int8/test_mixed_int8.py", # Mixed-int8 bitsandbytes test
|
"tests/utils/test_doc_samples.py", # Doc tests
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user