Generation: deprecate PreTrainedModel inheriting from GenerationMixin (#33203)
This commit is contained in:
@@ -67,6 +67,7 @@ if is_torch_available():
|
||||
BertModel,
|
||||
FunnelBaseModel,
|
||||
FunnelModel,
|
||||
GenerationMixin,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
ResNetBackbone,
|
||||
@@ -571,3 +572,20 @@ class AutoModelTest(unittest.TestCase):
|
||||
_ = AutoModelForCausalLM.from_pretrained(tmp_dir_out, trust_remote_code=True)
|
||||
self.assertTrue((Path(tmp_dir_out) / "modeling_fake_custom.py").is_file())
|
||||
self.assertTrue((Path(tmp_dir_out) / "configuration_fake_custom.py").is_file())
|
||||
|
||||
def test_custom_model_patched_generation_inheritance(self):
|
||||
"""
|
||||
Tests that our inheritance patching for generate-compatible models works as expected. Without this feature,
|
||||
old Hub models lose the ability to call `generate`.
|
||||
"""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_model_generation", trust_remote_code=True
|
||||
)
|
||||
self.assertTrue(model.__class__.__name__ == "NewModelForCausalLM")
|
||||
|
||||
# It inherits from GenerationMixin. This means it can `generate`. Because `PreTrainedModel` is scheduled to
|
||||
# stop inheriting from `GenerationMixin` in v4.50, this check will fail if patching is not present.
|
||||
self.assertTrue(isinstance(model, GenerationMixin))
|
||||
# More precisely, it directly inherits from GenerationMixin. This check would fail prior to v4.45 (inheritance
|
||||
# patching was added in v4.45)
|
||||
self.assertTrue("GenerationMixin" in str(model.__class__.__bases__))
|
||||
|
||||
Reference in New Issue
Block a user