[CI] Check test if the GenerationTesterMixin inheritance is correct 🐛 🔫 (#36180)
This commit is contained in:
@@ -106,6 +106,8 @@ from transformers.utils import (
|
||||
)
|
||||
from transformers.utils.generic import ContextManagers
|
||||
|
||||
from .generation.test_utils import GenerationTesterMixin
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.utils import compute_module_sizes
|
||||
@@ -4417,6 +4419,33 @@ class ModelTesterMixin:
|
||||
# If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
|
||||
_ = model(inputs_dict["input_ids"].to(torch_device))
|
||||
|
||||
def test_generation_tester_mixin_inheritance(self):
|
||||
"""
|
||||
Ensures that we have the generation tester mixin if the model can generate. The test will fail otherwise,
|
||||
forcing the mixin to be added -- and ensuring proper test coverage
|
||||
"""
|
||||
if len(self.all_generative_model_classes) > 0:
|
||||
self.assertTrue(
|
||||
issubclass(self.__class__, GenerationTesterMixin),
|
||||
msg=(
|
||||
"This model can call `generate` from `GenerationMixin`, so one of two things must happen: 1) the "
|
||||
"tester must inherit from `GenerationTesterMixin` to run `generate` tests, or 2) if the model "
|
||||
"doesn't fully support the original `generate` or has a custom `generate` with partial feature "
|
||||
"support, the tester must overwrite `all_generative_model_classes` to skip the failing classes "
|
||||
"(make sure to comment why). If `all_generative_model_classes` is overwritten as `()`, then we "
|
||||
"need to remove the `GenerationTesterMixin` inheritance -- no `generate` tests are being run."
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.assertFalse(
|
||||
issubclass(self.__class__, GenerationTesterMixin),
|
||||
msg=(
|
||||
"This model can't call `generate`, so its tester can't inherit `GenerationTesterMixin`. (If you "
|
||||
"think the model should be able to `generate`, the model may be missing the `GenerationMixin` "
|
||||
"inheritance)"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user