Generation: deprecate PreTrainedModel inheriting from GenerationMixin (#33203)
This commit is contained in:
@@ -2099,6 +2099,15 @@ class GenerationTesterMixin:
|
||||
)
|
||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||
|
||||
@pytest.mark.generate
def test_inherits_generation_mixin(self):
    """
    Tests that the model class directly inherits `GenerationMixin`, as opposed to relying on `PreTrainedModel`
    to inherit it.

    The check is a substring match of "GenerationMixin" against the string form of `model_class.__bases__`, so
    only the direct bases of each class are inspected.
    """
    for model_class in self.all_generative_model_classes:
        # `assertIn` reports both operands on failure, and the message names the offending class, so a failing
        # model can be identified straight from the test output.
        self.assertIn(
            "GenerationMixin",
            str(model_class.__bases__),
            msg=f"{model_class.__name__} does not directly inherit from `GenerationMixin`",
        )
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
config = config.text_config if hasattr(config, "text_config") else config
|
||||
|
||||
@@ -67,6 +67,7 @@ if is_torch_available():
|
||||
BertModel,
|
||||
FunnelBaseModel,
|
||||
FunnelModel,
|
||||
GenerationMixin,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
ResNetBackbone,
|
||||
@@ -571,3 +572,20 @@ class AutoModelTest(unittest.TestCase):
|
||||
_ = AutoModelForCausalLM.from_pretrained(tmp_dir_out, trust_remote_code=True)
|
||||
self.assertTrue((Path(tmp_dir_out) / "modeling_fake_custom.py").is_file())
|
||||
self.assertTrue((Path(tmp_dir_out) / "configuration_fake_custom.py").is_file())
|
||||
|
||||
def test_custom_model_patched_generation_inheritance(self):
    """
    Tests that our inheritance patching for generate-compatible models works as expected. Without this feature,
    old Hub models lose the ability to call `generate`.

    Loads the remote-code checkpoint "hf-internal-testing/test_dynamic_model_generation" through
    `AutoModelForCausalLM` and inspects the class of the returned model.
    """
    model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/test_dynamic_model_generation", trust_remote_code=True
    )
    # Dedicated assertions are used throughout so that a failure reports the actual values involved.
    self.assertEqual(model.__class__.__name__, "NewModelForCausalLM")

    # It inherits from GenerationMixin. This means it can `generate`. Because `PreTrainedModel` is scheduled to
    # stop inheriting from `GenerationMixin` in v4.50, this check will fail if patching is not present.
    self.assertIsInstance(model, GenerationMixin)
    # More precisely, it directly inherits from GenerationMixin. This check would fail prior to v4.45 (inheritance
    # patching was added in v4.45)
    self.assertIn("GenerationMixin", str(model.__class__.__bases__))
|
||||
|
||||
@@ -90,6 +90,7 @@ if is_torch_available():
|
||||
BertConfig,
|
||||
BertModel,
|
||||
CLIPTextModel,
|
||||
GenerationMixin,
|
||||
PreTrainedModel,
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
@@ -1715,6 +1716,32 @@ class ModelUtilsTest(TestCasePlus):
|
||||
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
|
||||
)
|
||||
|
||||
def test_can_generate(self):
    """Exercises `PreTrainedModel.can_generate` on `BertModel` and on three locally defined subclasses of it."""

    # Fixtures for the checks below. The class names and the in-class method definitions are kept exactly as
    # they are, since the method under test is not shown here and may inspect them.
    class DummyBertWithMixin(BertModel, GenerationMixin):
        """`BertModel` subclass that lists `GenerationMixin` among its direct bases."""

    class DummyBertWithGenerate(BertModel):
        """`BertModel` subclass that defines its own `generate` method."""

        def generate(self):
            return None

    class DummyBertWithPrepareInputs(BertModel):
        """`BertModel` subclass that defines its own `prepare_inputs_for_generation` method."""

        def prepare_inputs_for_generation(self):
            return None

    # 1 - Baseline: a model is NOT able to generate by default
    plain_model_can_generate = BertModel.can_generate()
    self.assertFalse(plain_model_can_generate)

    # 2 - Usual route: direct inheritance from `GenerationMixin` makes the model able to generate
    mixin_model_can_generate = DummyBertWithMixin.can_generate()
    self.assertTrue(mixin_model_can_generate)

    # 3 - Alternative route: the model implements a `generate` method
    custom_generate_can_generate = DummyBertWithGenerate.can_generate()
    self.assertTrue(custom_generate_can_generate)

    # 4 - BC route: a custom `prepare_inputs_for_generation` also counts (such models were assumed to inherit
    # `GenerationMixin`)
    custom_prepare_can_generate = DummyBertWithPrepareInputs.can_generate()
    self.assertTrue(custom_prepare_can_generate)
|
||||
|
||||
def test_save_and_load_config_with_custom_generation(self):
|
||||
"""
|
||||
Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter
|
||||
|
||||
Reference in New Issue
Block a user