Generation: deprecate PreTrainedModel inheriting from GenerationMixin (#33203)

This commit is contained in:
Joao Gante
2024-09-23 18:28:36 +01:00
committed by GitHub
parent 1456120929
commit e15687fffe
126 changed files with 407 additions and 184 deletions

View File

@@ -2099,6 +2099,15 @@ class GenerationTesterMixin:
)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
@pytest.mark.generate
def test_inherits_generation_mixin(self):
"""
Tests that the model class directly inherits `GenerationMixin`, as opposed to relying on `PreTrainedModel`
to inherit it.
"""
for model_class in self.all_generative_model_classes:
self.assertTrue("GenerationMixin" in str(model_class.__bases__))
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
config = config.text_config if hasattr(config, "text_config") else config