[CI] Check test if the GenerationTesterMixin inheritance is correct 🐛 🔫 (#36180)

This commit is contained in:
Joao Gante
2025-02-21 10:18:20 +00:00
committed by GitHub
parent a957b7911a
commit 678885bbbd
39 changed files with 180 additions and 68 deletions

View File

@@ -106,6 +106,8 @@ from transformers.utils import (
)
from transformers.utils.generic import ContextManagers
from .generation.test_utils import GenerationTesterMixin
if is_accelerate_available():
from accelerate.utils import compute_module_sizes
@@ -4417,6 +4419,33 @@ class ModelTesterMixin:
# If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
_ = model(inputs_dict["input_ids"].to(torch_device))
def test_generation_tester_mixin_inheritance(self):
"""
Ensures that we have the generation tester mixin if the model can generate. The test will fail otherwise,
forcing the mixin to be added -- and ensuring proper test coverage
"""
if len(self.all_generative_model_classes) > 0:
self.assertTrue(
issubclass(self.__class__, GenerationTesterMixin),
msg=(
"This model can call `generate` from `GenerationMixin`, so one of two things must happen: 1) the "
"tester must inherit from `GenerationTesterMixin` to run `generate` tests, or 2) if the model "
"doesn't fully support the original `generate` or has a custom `generate` with partial feature "
"support, the tester must overwrite `all_generative_model_classes` to skip the failing classes "
"(make sure to comment why). If `all_generative_model_classes` is overwritten as `()`, then we "
"need to remove the `GenerationTesterMixin` inheritance -- no `generate` tests are being run."
),
)
else:
self.assertFalse(
issubclass(self.__class__, GenerationTesterMixin),
msg=(
"This model can't call `generate`, so its tester can't inherit `GenerationTesterMixin`. (If you "
"think the model should be able to `generate`, the model may be missing the `GenerationMixin` "
"inheritance)"
),
)
global_rng = random.Random()