Generation: deprecate PreTrainedModel inheriting from GenerationMixin (#33203)

This commit is contained in:
Joao Gante
2024-09-23 18:28:36 +01:00
committed by GitHub
parent 1456120929
commit e15687fffe
126 changed files with 407 additions and 184 deletions

View File

@@ -67,6 +67,7 @@ if is_torch_available():
BertModel,
FunnelBaseModel,
FunnelModel,
GenerationMixin,
GPT2Config,
GPT2LMHeadModel,
ResNetBackbone,
@@ -571,3 +572,20 @@ class AutoModelTest(unittest.TestCase):
_ = AutoModelForCausalLM.from_pretrained(tmp_dir_out, trust_remote_code=True)
self.assertTrue((Path(tmp_dir_out) / "modeling_fake_custom.py").is_file())
self.assertTrue((Path(tmp_dir_out) / "configuration_fake_custom.py").is_file())
def test_custom_model_patched_generation_inheritance(self):
    """
    Tests that our inheritance patching for generate-compatible models works as expected. Without this feature,
    old Hub models lose the ability to call `generate`.

    The checkpoint is loaded through `AutoModelForCausalLM` with `trust_remote_code=True`, and the resulting
    model's class name, instance type and direct base classes are then inspected.
    """
    model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/test_dynamic_model_generation", trust_remote_code=True
    )
    # Dedicated assert methods are used instead of `assertTrue(<expr>)` so that a failure reports the offending
    # values rather than a bare "False is not true".
    self.assertEqual(model.__class__.__name__, "NewModelForCausalLM")

    # It inherits from GenerationMixin. This means it can `generate`. Because `PreTrainedModel` is scheduled to
    # stop inheriting from `GenerationMixin` in v4.50, this check will fail if patching is not present.
    self.assertIsInstance(model, GenerationMixin)

    # More precisely, it directly inherits from GenerationMixin. This check would fail prior to v4.45 (inheritance
    # patching was added in v4.45). The comparison is done on the string form of the direct bases, i.e. by name.
    self.assertIn("GenerationMixin", str(model.__class__.__bases__))