Generation: deprecate PreTrainedModel inheriting from GenerationMixin (#33203)

This commit is contained in:
Joao Gante
2024-09-23 18:28:36 +01:00
committed by GitHub
parent 1456120929
commit e15687fffe
126 changed files with 407 additions and 184 deletions

View File

@@ -90,6 +90,7 @@ if is_torch_available():
BertConfig,
BertModel,
CLIPTextModel,
GenerationMixin,
PreTrainedModel,
T5Config,
T5ForConditionalGeneration,
@@ -1715,6 +1716,32 @@ class ModelUtilsTest(TestCasePlus):
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
)
def test_can_generate(self):
    """Checks which model classes `PreTrainedModel.can_generate` reports as able to generate."""

    # Fixture: inherits from `GenerationMixin` directly (the most common way for a model to be able to generate).
    class DummyBertWithMixin(BertModel, GenerationMixin):
        pass

    # Fixture: provides its own `generate` method instead of inheriting the mixin.
    class DummyBertWithGenerate(BertModel):
        def generate(self):
            return None

    # Fixture (BC): only customizes `prepare_inputs_for_generation`; such models were assumed to inherit
    # from `GenerationMixin`.
    class DummyBertWithPrepareInputs(BertModel):
        def prepare_inputs_for_generation(self):
            return None

    # Case 1: by default, a model is NOT able to generate.
    self.assertFalse(BertModel.can_generate())

    # Case 2: direct `GenerationMixin` inheritance enables generation.
    self.assertTrue(DummyBertWithMixin.can_generate())

    # Case 3: implementing a `generate` method enables generation.
    self.assertTrue(DummyBertWithGenerate.can_generate())

    # Case 4: backward compatibility for models with a custom `prepare_inputs_for_generation`.
    self.assertTrue(DummyBertWithPrepareInputs.can_generate())
def test_save_and_load_config_with_custom_generation(self):
"""
Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter