From 3557f9a14a920caeb489f637ac2f6d1a7559df04 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 26 Sep 2024 18:11:14 +0100 Subject: [PATCH] Generate: `can_generate()` recursive check (#33718) * add recursive check and test warnings * missing space * models without can_generate --- src/transformers/modeling_utils.py | 6 ++++++ tests/utils/test_modeling_utils.py | 32 +++++++++++++++++++++++++----- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3e3d789087..d0f4239c38 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1645,6 +1645,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Model class overwrites `generate` (e.g. time series models) -> can generate if str(cls.__name__) in str(cls.generate): return True + # The class inherits from a class that can generate (recursive check) -> can generate + for base in cls.__bases__: + if not hasattr(base, "can_generate"): + continue + if "PreTrainedModel" not in str(base) and base.can_generate(): + return True # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this # was how we detected whether a model could generate. if "GenerationMixin" not in str(cls.prepare_inputs_for_generation): diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 5155647059..3317a47d75 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1718,29 +1718,51 @@ class ModelUtilsTest(TestCasePlus): def test_can_generate(self): """Tests the behavior of `PreTrainedModel.can_generate` method.""" + logger = logging.get_logger("transformers.modeling_utils") + logger.warning_once.cache_clear() + # 1 - By default, a model CAN'T generate - self.assertFalse(BertModel.can_generate()) + can_generate = BertModel.can_generate() + self.assertFalse(can_generate) # 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly class DummyBertWithMixin(BertModel, GenerationMixin): pass - self.assertTrue(DummyBertWithMixin.can_generate()) + with CaptureLogger(logger) as cl: + can_generate = DummyBertWithMixin.can_generate() + self.assertTrue("" == cl.out) + self.assertTrue(can_generate) # 3 - Alternatively, a model can implement a `generate` method class DummyBertWithGenerate(BertModel): def generate(self): pass - self.assertTrue(DummyBertWithGenerate.can_generate()) + with CaptureLogger(logger) as cl: + can_generate = DummyBertWithGenerate.can_generate() + self.assertTrue("" == cl.out) + self.assertTrue(can_generate) - # 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited + # 4 - Finally, it can inherit from a model that can generate + class DummyBertWithParent(DummyBertWithMixin): + pass + + with CaptureLogger(logger) as cl: + can_generate = DummyBertWithParent.can_generate() + self.assertTrue("" == cl.out) + self.assertTrue(can_generate) + + # 5 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited # `GenerationMixin`) class DummyBertWithPrepareInputs(BertModel): def prepare_inputs_for_generation(self): pass - self.assertTrue(DummyBertWithPrepareInputs.can_generate()) + with CaptureLogger(logger) as cl: + can_generate = DummyBertWithPrepareInputs.can_generate() + self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out) + self.assertTrue(can_generate) def test_save_and_load_config_with_custom_generation(self): """