Generate: can_generate() recursive check (#33718)
* add recursive check and test warnings * missing space * models without can_generate
This commit is contained in:
@@ -1645,6 +1645,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# Model class overwrites `generate` (e.g. time series models) -> can generate
|
||||
if str(cls.__name__) in str(cls.generate):
|
||||
return True
|
||||
# The class inherits from a class that can generate (recursive check) -> can generate
|
||||
for base in cls.__bases__:
|
||||
if not hasattr(base, "can_generate"):
|
||||
continue
|
||||
if "PreTrainedModel" not in str(base) and base.can_generate():
|
||||
return True
|
||||
# BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
|
||||
# was how we detected whether a model could generate.
|
||||
if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
|
||||
|
||||
@@ -1718,29 +1718,51 @@ class ModelUtilsTest(TestCasePlus):
|
||||
|
||||
def test_can_generate(self):
|
||||
"""Tests the behavior of `PreTrainedModel.can_generate` method."""
|
||||
logger = logging.get_logger("transformers.modeling_utils")
|
||||
logger.warning_once.cache_clear()
|
||||
|
||||
# 1 - By default, a model CAN'T generate
|
||||
self.assertFalse(BertModel.can_generate())
|
||||
can_generate = BertModel.can_generate()
|
||||
self.assertFalse(can_generate)
|
||||
|
||||
# 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly
|
||||
class DummyBertWithMixin(BertModel, GenerationMixin):
|
||||
pass
|
||||
|
||||
self.assertTrue(DummyBertWithMixin.can_generate())
|
||||
with CaptureLogger(logger) as cl:
|
||||
can_generate = DummyBertWithMixin.can_generate()
|
||||
self.assertTrue("" == cl.out)
|
||||
self.assertTrue(can_generate)
|
||||
|
||||
# 3 - Alternatively, a model can implement a `generate` method
|
||||
class DummyBertWithGenerate(BertModel):
|
||||
def generate(self):
|
||||
pass
|
||||
|
||||
self.assertTrue(DummyBertWithGenerate.can_generate())
|
||||
with CaptureLogger(logger) as cl:
|
||||
can_generate = DummyBertWithGenerate.can_generate()
|
||||
self.assertTrue("" == cl.out)
|
||||
self.assertTrue(can_generate)
|
||||
|
||||
# 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
|
||||
# 4 - Finally, it can inherit from a model that can generate
|
||||
class DummyBertWithParent(DummyBertWithMixin):
|
||||
pass
|
||||
|
||||
with CaptureLogger(logger) as cl:
|
||||
can_generate = DummyBertWithParent.can_generate()
|
||||
self.assertTrue("" == cl.out)
|
||||
self.assertTrue(can_generate)
|
||||
|
||||
# 5 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
|
||||
# `GenerationMixin`)
|
||||
class DummyBertWithPrepareInputs(BertModel):
|
||||
def prepare_inputs_for_generation(self):
|
||||
pass
|
||||
|
||||
self.assertTrue(DummyBertWithPrepareInputs.can_generate())
|
||||
with CaptureLogger(logger) as cl:
|
||||
can_generate = DummyBertWithPrepareInputs.can_generate()
|
||||
self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
|
||||
self.assertTrue(can_generate)
|
||||
|
||||
def test_save_and_load_config_with_custom_generation(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user