Generate: can_generate() recursive check (#33718)
* add recursive check and test warnings * missing space * models without can_generate
This commit is contained in:
@@ -1645,6 +1645,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Model class overwrites `generate` (e.g. time series models) -> can generate
|
# Model class overwrites `generate` (e.g. time series models) -> can generate
|
||||||
if str(cls.__name__) in str(cls.generate):
|
if str(cls.__name__) in str(cls.generate):
|
||||||
return True
|
return True
|
||||||
|
# The class inherits from a class that can generate (recursive check) -> can generate
|
||||||
|
for base in cls.__bases__:
|
||||||
|
if not hasattr(base, "can_generate"):
|
||||||
|
continue
|
||||||
|
if "PreTrainedModel" not in str(base) and base.can_generate():
|
||||||
|
return True
|
||||||
# BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
|
# BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
|
||||||
# was how we detected whether a model could generate.
|
# was how we detected whether a model could generate.
|
||||||
if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
|
if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
|
||||||
|
|||||||
@@ -1718,29 +1718,51 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
|
|
||||||
def test_can_generate(self):
|
def test_can_generate(self):
|
||||||
"""Tests the behavior of `PreTrainedModel.can_generate` method."""
|
"""Tests the behavior of `PreTrainedModel.can_generate` method."""
|
||||||
|
logger = logging.get_logger("transformers.modeling_utils")
|
||||||
|
logger.warning_once.cache_clear()
|
||||||
|
|
||||||
# 1 - By default, a model CAN'T generate
|
# 1 - By default, a model CAN'T generate
|
||||||
self.assertFalse(BertModel.can_generate())
|
can_generate = BertModel.can_generate()
|
||||||
|
self.assertFalse(can_generate)
|
||||||
|
|
||||||
# 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly
|
# 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly
|
||||||
class DummyBertWithMixin(BertModel, GenerationMixin):
|
class DummyBertWithMixin(BertModel, GenerationMixin):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.assertTrue(DummyBertWithMixin.can_generate())
|
with CaptureLogger(logger) as cl:
|
||||||
|
can_generate = DummyBertWithMixin.can_generate()
|
||||||
|
self.assertTrue("" == cl.out)
|
||||||
|
self.assertTrue(can_generate)
|
||||||
|
|
||||||
# 3 - Alternatively, a model can implement a `generate` method
|
# 3 - Alternatively, a model can implement a `generate` method
|
||||||
class DummyBertWithGenerate(BertModel):
|
class DummyBertWithGenerate(BertModel):
|
||||||
def generate(self):
|
def generate(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.assertTrue(DummyBertWithGenerate.can_generate())
|
with CaptureLogger(logger) as cl:
|
||||||
|
can_generate = DummyBertWithGenerate.can_generate()
|
||||||
|
self.assertTrue("" == cl.out)
|
||||||
|
self.assertTrue(can_generate)
|
||||||
|
|
||||||
# 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
|
# 4 - Finally, it can inherit from a model that can generate
|
||||||
|
class DummyBertWithParent(DummyBertWithMixin):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
can_generate = DummyBertWithParent.can_generate()
|
||||||
|
self.assertTrue("" == cl.out)
|
||||||
|
self.assertTrue(can_generate)
|
||||||
|
|
||||||
|
# 5 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
|
||||||
# `GenerationMixin`)
|
# `GenerationMixin`)
|
||||||
class DummyBertWithPrepareInputs(BertModel):
|
class DummyBertWithPrepareInputs(BertModel):
|
||||||
def prepare_inputs_for_generation(self):
|
def prepare_inputs_for_generation(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.assertTrue(DummyBertWithPrepareInputs.can_generate())
|
with CaptureLogger(logger) as cl:
|
||||||
|
can_generate = DummyBertWithPrepareInputs.can_generate()
|
||||||
|
self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
|
||||||
|
self.assertTrue(can_generate)
|
||||||
|
|
||||||
def test_save_and_load_config_with_custom_generation(self):
|
def test_save_and_load_config_with_custom_generation(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user