Generation: deprecate PreTrainedModel inheriting from GenerationMixin (#33203)
This commit is contained in:
@@ -90,6 +90,7 @@ if is_torch_available():
|
||||
BertConfig,
|
||||
BertModel,
|
||||
CLIPTextModel,
|
||||
GenerationMixin,
|
||||
PreTrainedModel,
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
@@ -1715,6 +1716,32 @@ class ModelUtilsTest(TestCasePlus):
|
||||
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
|
||||
)
|
||||
|
||||
def test_can_generate(self):
|
||||
"""Tests the behavior of `PreTrainedModel.can_generate` method."""
|
||||
# 1 - By default, a model CAN'T generate
|
||||
self.assertFalse(BertModel.can_generate())
|
||||
|
||||
# 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly
|
||||
class DummyBertWithMixin(BertModel, GenerationMixin):
|
||||
pass
|
||||
|
||||
self.assertTrue(DummyBertWithMixin.can_generate())
|
||||
|
||||
# 3 - Alternatively, a model can implement a `generate` method
|
||||
class DummyBertWithGenerate(BertModel):
|
||||
def generate(self):
|
||||
pass
|
||||
|
||||
self.assertTrue(DummyBertWithGenerate.can_generate())
|
||||
|
||||
# 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
|
||||
# `GenerationMixin`)
|
||||
class DummyBertWithPrepareInputs(BertModel):
|
||||
def prepare_inputs_for_generation(self):
|
||||
pass
|
||||
|
||||
self.assertTrue(DummyBertWithPrepareInputs.can_generate())
|
||||
|
||||
def test_save_and_load_config_with_custom_generation(self):
|
||||
"""
|
||||
Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter
|
||||
|
||||
Reference in New Issue
Block a user