Generate: can_generate() recursive check (#33718)

* add recursive check and test warnings

* missing space

* models without can_generate
This commit is contained in:
Joao Gante
2024-09-26 18:11:14 +01:00
committed by GitHub
parent 9f97c39384
commit 3557f9a14a
2 changed files with 33 additions and 5 deletions

View File

@@ -1645,6 +1645,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Model class overwrites `generate` (e.g. time series models) -> can generate
if str(cls.__name__) in str(cls.generate):
return True
# The class inherits from a class that can generate (recursive check) -> can generate
for base in cls.__bases__:
if not hasattr(base, "can_generate"):
continue
if "PreTrainedModel" not in str(base) and base.can_generate():
return True
# BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
# was how we detected whether a model could generate.
if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):

View File

@@ -1718,29 +1718,51 @@ class ModelUtilsTest(TestCasePlus):
def test_can_generate(self):
"""Tests the behavior of `PreTrainedModel.can_generate` method."""
logger = logging.get_logger("transformers.modeling_utils")
logger.warning_once.cache_clear()
# 1 - By default, a model CAN'T generate
self.assertFalse(BertModel.can_generate()) can_generate = BertModel.can_generate()
self.assertFalse(can_generate)
# 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly
class DummyBertWithMixin(BertModel, GenerationMixin):
pass
self.assertTrue(DummyBertWithMixin.can_generate()) with CaptureLogger(logger) as cl:
can_generate = DummyBertWithMixin.can_generate()
self.assertTrue("" == cl.out)
self.assertTrue(can_generate)
# 3 - Alternatively, a model can implement a `generate` method
class DummyBertWithGenerate(BertModel):
def generate(self):
pass
self.assertTrue(DummyBertWithGenerate.can_generate()) with CaptureLogger(logger) as cl:
can_generate = DummyBertWithGenerate.can_generate()
self.assertTrue("" == cl.out)
self.assertTrue(can_generate)
# 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited # 4 - Finally, it can inherit from a model that can generate
class DummyBertWithParent(DummyBertWithMixin):
pass
with CaptureLogger(logger) as cl:
can_generate = DummyBertWithParent.can_generate()
self.assertTrue("" == cl.out)
self.assertTrue(can_generate)
# 5 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
# `GenerationMixin`)
class DummyBertWithPrepareInputs(BertModel):
def prepare_inputs_for_generation(self):
pass
self.assertTrue(DummyBertWithPrepareInputs.can_generate()) with CaptureLogger(logger) as cl:
can_generate = DummyBertWithPrepareInputs.can_generate()
self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
self.assertTrue(can_generate)
def test_save_and_load_config_with_custom_generation(self):
"""