Template for framework-agnostic tests (#21348)
This commit is contained in:
41
tests/generation/test_framework_agnostic.py
Normal file
41
tests/generation/test_framework_agnostic.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Framework agnostic tests for generate()-related methods.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
class GenerationIntegrationTestsMixin:
|
||||
|
||||
# To be populated by the child classes
|
||||
framework_dependent_parameters = {
|
||||
"AutoModelForSeq2SeqLM": None,
|
||||
"create_tensor_fn": None,
|
||||
"return_tensors": None,
|
||||
}
|
||||
|
||||
def test_validate_generation_inputs(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
create_tensor_fn = self.framework_dependent_parameters["create_tensor_fn"]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
encoder_input_str = "Hello world"
|
||||
input_ids = tokenizer(encoder_input_str, return_tensors=return_tensors).input_ids
|
||||
|
||||
# typos are quickly detected (the correct argument is `do_sample`)
|
||||
with self.assertRaisesRegex(ValueError, "do_samples"):
|
||||
model.generate(input_ids, do_samples=True)
|
||||
|
||||
# arbitrary arguments that will not be used anywhere are also not accepted
|
||||
with self.assertRaisesRegex(ValueError, "foo"):
|
||||
fake_model_kwargs = {"foo": "bar"}
|
||||
model.generate(input_ids, **fake_model_kwargs)
|
||||
|
||||
# however, valid model_kwargs are accepted
|
||||
valid_model_kwargs = {"attention_mask": create_tensor_fn(np.zeros_like(input_ids))}
|
||||
model.generate(input_ids, **valid_model_kwargs)
|
||||
Reference in New Issue
Block a user