Template for framework-agnostic tests (#21348)

This commit is contained in:
Joao Gante
2023-01-31 11:33:18 +00:00
committed by GitHub
parent 5451f8896c
commit 623346ab18
4 changed files with 68 additions and 41 deletions

View File

@@ -23,6 +23,7 @@ from transformers import is_torch_available, pipeline
from transformers.testing_utils import require_torch, slow, torch_device
from ..test_modeling_common import floats_tensor, ids_tensor
from .test_framework_agnostic import GenerationIntegrationTestsMixin
if is_torch_available():
@@ -1790,7 +1791,16 @@ class UtilsFunctionsTest(unittest.TestCase):
@require_torch
class GenerationIntegrationTests(unittest.TestCase):
class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
if is_torch_available():
framework_dependent_parameters = {
"AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
"create_tensor_fn": torch.tensor,
"return_tensors": "pt",
}
@slow
def test_diverse_beam_search(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
@@ -3022,26 +3032,6 @@ class GenerationIntegrationTests(unittest.TestCase):
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
self.assertTrue(max_score_diff < 1e-5)
def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta")
encoder_input_str = "Hello world"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
# typos are quickly detected (the correct argument is `do_sample`)
with self.assertRaisesRegex(ValueError, "do_samples"):
model.generate(input_ids, do_samples=True)
# arbitrary arguments that will not be used anywhere are also not accepted
with self.assertRaisesRegex(ValueError, "foo"):
fake_model_kwargs = {"foo": "bar"}
model.generate(input_ids, **fake_model_kwargs)
# However, valid model_kwargs are accepted
valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
model.generate(input_ids, **valid_model_kwargs)
def test_eos_token_id_int_and_list_greedy_search(self):
generation_kwargs = {
"do_sample": False,