Generation: deprecate PreTrainedModel inheriting from GenerationMixin (#33203)
This commit is contained in:
@@ -2099,6 +2099,15 @@ class GenerationTesterMixin:
|
||||
)
|
||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_inherits_generation_mixin(self):
|
||||
"""
|
||||
Tests that the model class directly inherits `GenerationMixin`, as opposed to relying on `PreTrainedModel`
|
||||
to inherit it.
|
||||
"""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
self.assertTrue("GenerationMixin" in str(model_class.__bases__))
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
config = config.text_config if hasattr(config, "text_config") else config
|
||||
|
||||
Reference in New Issue
Block a user