Generation: deprecate PreTrainedModel inheriting from GenerationMixin (#33203)

This commit is contained in:
Joao Gante
2024-09-23 18:28:36 +01:00
committed by GitHub
parent 1456120929
commit e15687fffe
126 changed files with 407 additions and 184 deletions

View File

@@ -2099,6 +2099,15 @@ class GenerationTesterMixin:
)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
@pytest.mark.generate
def test_inherits_generation_mixin(self):
"""
Tests that the model class directly inherits `GenerationMixin`, as opposed to relying on `PreTrainedModel`
to inherit it.
"""
for model_class in self.all_generative_model_classes:
self.assertTrue("GenerationMixin" in str(model_class.__bases__))
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
config = config.text_config if hasattr(config, "text_config") else config

View File

@@ -67,6 +67,7 @@ if is_torch_available():
BertModel,
FunnelBaseModel,
FunnelModel,
GenerationMixin,
GPT2Config,
GPT2LMHeadModel,
ResNetBackbone,
@@ -571,3 +572,20 @@ class AutoModelTest(unittest.TestCase):
_ = AutoModelForCausalLM.from_pretrained(tmp_dir_out, trust_remote_code=True)
self.assertTrue((Path(tmp_dir_out) / "modeling_fake_custom.py").is_file())
self.assertTrue((Path(tmp_dir_out) / "configuration_fake_custom.py").is_file())
def test_custom_model_patched_generation_inheritance(self):
    """
    Tests that our inheritance patching for generate-compatible models works as expected. Without this feature,
    old Hub models lose the ability to call `generate`.

    Loads the remote-code checkpoint "hf-internal-testing/test_dynamic_model_generation" through
    `AutoModelForCausalLM` and inspects the class of the returned model.
    """
    model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/test_dynamic_model_generation", trust_remote_code=True
    )
    # Dedicated assertion methods are used so that a failure reports the actual value that was observed.
    self.assertEqual(model.__class__.__name__, "NewModelForCausalLM")

    # It inherits from GenerationMixin. This means it can `generate`. Because `PreTrainedModel` is scheduled to
    # stop inheriting from `GenerationMixin` in v4.50, this check will fail if patching is not present.
    self.assertIsInstance(model, GenerationMixin)

    # More precisely, it directly inherits from GenerationMixin. This check would fail prior to v4.45 (inheritance
    # patching was added in v4.45)
    self.assertIn("GenerationMixin", str(model.__class__.__bases__))

View File

@@ -90,6 +90,7 @@ if is_torch_available():
BertConfig,
BertModel,
CLIPTextModel,
GenerationMixin,
PreTrainedModel,
T5Config,
T5ForConditionalGeneration,
@@ -1715,6 +1716,32 @@ class ModelUtilsTest(TestCasePlus):
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
)
def test_can_generate(self):
    """
    Checks the result of the `PreTrainedModel.can_generate` method on `BertModel` and on three local subclasses
    of it.
    """
    # Case 1: out of the box, a model class reports that it CAN'T generate.
    plain_model_result = BertModel.can_generate()
    self.assertFalse(plain_model_result)

    # Case 2: the usual route to being generate-capable is listing `GenerationMixin` as a direct base.
    class DummyBertWithMixin(BertModel, GenerationMixin):
        pass

    with_mixin_result = DummyBertWithMixin.can_generate()
    self.assertTrue(with_mixin_result)

    # Case 3: a class that defines its own `generate` method also counts.
    class DummyBertWithGenerate(BertModel):
        def generate(self):
            pass

    with_generate_result = DummyBertWithGenerate.can_generate()
    self.assertTrue(with_generate_result)

    # Case 4 (backward compatibility): a class that defines its own `prepare_inputs_for_generation` counts as
    # well, since such classes were assumed to inherit `GenerationMixin`.
    class DummyBertWithPrepareInputs(BertModel):
        def prepare_inputs_for_generation(self):
            pass

    with_prepare_inputs_result = DummyBertWithPrepareInputs.can_generate()
    self.assertTrue(with_prepare_inputs_result)
def test_save_and_load_config_with_custom_generation(self):
"""
Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter