diff --git a/tests/generation/test_framework_agnostic.py b/tests/generation/test_framework_agnostic.py new file mode 100644 index 0000000000..31cc78d411 --- /dev/null +++ b/tests/generation/test_framework_agnostic.py @@ -0,0 +1,41 @@ +""" +Framework-agnostic tests for generate()-related methods, inherited by the framework-specific test classes. +""" + +import numpy as np + +from transformers import AutoTokenizer + + +class GenerationIntegrationTestsMixin: + + # Placeholders, to be populated by the child classes with their framework-specific values + framework_dependent_parameters = { + "AutoModelForSeq2SeqLM": None, + "create_tensor_fn": None, + "return_tensors": None, + } + + def test_validate_generation_inputs(self): + model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] + return_tensors = self.framework_dependent_parameters["return_tensors"] + create_tensor_fn = self.framework_dependent_parameters["create_tensor_fn"] + + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + model = model_cls.from_pretrained("hf-internal-testing/tiny-random-t5") + + encoder_input_str = "Hello world" + input_ids = tokenizer(encoder_input_str, return_tensors=return_tensors).input_ids + + # typos are quickly detected (the correct argument is `do_sample`) + with self.assertRaisesRegex(ValueError, "do_samples"): + model.generate(input_ids, do_samples=True) + + # arbitrary arguments that will not be used anywhere are also not accepted + with self.assertRaisesRegex(ValueError, "foo"): + fake_model_kwargs = {"foo": "bar"} + model.generate(input_ids, **fake_model_kwargs) + + # however, valid model_kwargs are accepted: this call is expected to complete without raising + valid_model_kwargs = {"attention_mask": create_tensor_fn(np.zeros_like(input_ids))} + model.generate(input_ids, **valid_model_kwargs) diff --git a/tests/generation/test_tf_utils.py b/tests/generation/test_tf_utils.py index d0d284182b..42eac59e50 100644 --- a/tests/generation/test_tf_utils.py +++ b/tests/generation/test_tf_utils.py @@ -19,11 +19,13 @@ import unittest from transformers import is_tf_available from 
transformers.testing_utils import require_tf, slow +from .test_framework_agnostic import GenerationIntegrationTestsMixin + if is_tf_available(): import tensorflow as tf - from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering + from transformers import TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering @require_tf @@ -124,7 +126,16 @@ class UtilsFunctionsTest(unittest.TestCase): @require_tf -class TFGenerationIntegrationTests(unittest.TestCase): +class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin): + + # setting framework_dependent_parameters needs to be gated, just like its contents' imports + if is_tf_available(): + framework_dependent_parameters = { + "AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM, + "create_tensor_fn": tf.convert_to_tensor, + "return_tensors": "tf", + } + @slow def test_generate_tf_function_export(self): test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2") @@ -165,19 +176,3 @@ class TFGenerationIntegrationTests(unittest.TestCase): tf_func_outputs = serving_func(**inputs)["sequences"] tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length) tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs) - - def test_validate_generation_inputs(self): - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") - model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5") - - encoder_input_str = "Hello world" - input_ids = tokenizer(encoder_input_str, return_tensors="tf").input_ids - - # typos are quickly detected (the correct argument is `do_sample`) - with self.assertRaisesRegex(ValueError, "do_samples"): - model.generate(input_ids, do_samples=True) - - # arbitrary arguments that will not be used anywhere are also not accepted - with self.assertRaisesRegex(ValueError, "foo"): - fake_model_kwargs = {"foo": "bar"} - 
model.generate(input_ids, **fake_model_kwargs) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 5a5e578ee4..1546d14b43 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -23,6 +23,7 @@ from transformers import is_torch_available, pipeline from transformers.testing_utils import require_torch, slow, torch_device from ..test_modeling_common import floats_tensor, ids_tensor +from .test_framework_agnostic import GenerationIntegrationTestsMixin if is_torch_available(): @@ -1790,7 +1791,16 @@ class UtilsFunctionsTest(unittest.TestCase): @require_torch -class GenerationIntegrationTests(unittest.TestCase): +class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin): + + # setting framework_dependent_parameters needs to be gated, just like its contents' imports + if is_torch_available(): + framework_dependent_parameters = { + "AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM, + "create_tensor_fn": torch.tensor, + "return_tensors": "pt", + } + @slow def test_diverse_beam_search(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood. 
@@ -3022,26 +3032,6 @@ class GenerationIntegrationTests(unittest.TestCase): max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max() self.assertTrue(max_score_diff < 1e-5) - def test_validate_generation_inputs(self): - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta") - model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta") - - encoder_input_str = "Hello world" - input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - # typos are quickly detected (the correct argument is `do_sample`) - with self.assertRaisesRegex(ValueError, "do_samples"): - model.generate(input_ids, do_samples=True) - - # arbitrary arguments that will not be used anywhere are also not accepted - with self.assertRaisesRegex(ValueError, "foo"): - fake_model_kwargs = {"foo": "bar"} - model.generate(input_ids, **fake_model_kwargs) - - # However, valid model_kwargs are accepted - valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)} - model.generate(input_ids, **valid_model_kwargs) - def test_eos_token_id_int_and_list_greedy_search(self): generation_kwargs = { "do_sample": False, diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index d388c11361..2b28896a5d 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -466,12 +466,13 @@ def module_to_test_file(module_fname): # This list contains the list of test files we expect never to be launched from a change in a module/util. Those are # launched separately. 
EXPECTED_TEST_FILES_NEVER_TOUCHED = [ - "tests/utils/test_doc_samples.py", # Doc tests + "tests/generation/test_framework_agnostic.py", # Mixin only: its tests run through the inheriting test classes + "tests/mixed_int8/test_mixed_int8.py", # Mixed-int8 bitsandbytes test "tests/pipelines/test_pipelines_common.py", # Actually checked by the pipeline based file "tests/sagemaker/test_single_node_gpu.py", # SageMaker test "tests/sagemaker/test_multi_node_model_parallel.py", # SageMaker test "tests/sagemaker/test_multi_node_data_parallel.py", # SageMaker test - "tests/mixed_int8/test_mixed_int8.py", # Mixed-int8 bitsandbytes test + "tests/utils/test_doc_samples.py", # Doc tests ]