From 2256875a77df7819499d1cf1076e44af048f1439 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 17 Mar 2025 21:56:18 +0800 Subject: [PATCH] fix can_generate (#36570) * fix can_generate Signed-off-by: jiqing-feng * fix can generate for speecht5 and blip Signed-off-by: jiqing-feng * fix speecht5 tests Signed-off-by: jiqing-feng * fix Signed-off-by: jiqing-feng --------- Signed-off-by: jiqing-feng Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> --- src/transformers/models/blip/modeling_blip.py | 2 +- src/transformers/models/speecht5/modeling_speecht5.py | 7 +++++++ tests/models/bark/test_modeling_bark.py | 4 ++++ tests/models/speecht5/test_modeling_speecht5.py | 7 +++++++ 4 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 154d6eb9f8..2a7430b937 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -1233,7 +1233,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin): """, BLIP_START_DOCSTRING, ) -class BlipForQuestionAnswering(BlipPreTrainedModel): +class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin): config_class = BlipConfig _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"] diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 81cca3bc84..39422cd495 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -2631,6 +2631,13 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): # Initialize weights and apply final processing self.post_init() + @classmethod + def can_generate(cls) -> bool: + # Speecht5 has a unique model structure, where the external class (`SpeechT5ForTextToSpeech`) doesn't need to inherit from + # `GenerationMixin` (it has a non-standard generation method). This means that the base `can_generate()` will return `False`, + # but we need to override it so as to do `GenerationConfig` handling in multiple parts of the codebase. + return True + def get_encoder(self): return self.speecht5.get_encoder() diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index d94f6d26d6..b44c3d3f04 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -1076,6 +1076,10 @@ class BarkModelIntegrationTests(unittest.TestCase): fine_generation_config = BarkFineGenerationConfig(**self.model.generation_config.fine_acoustics_config) return fine_generation_config + def test_model_can_generate(self): + # Bark has custom generate without inheriting GenerationMixin. This test could prevent regression. + self.assertTrue(self.model.can_generate()) + @slow def test_generate_semantic(self): input_ids = self.inputs diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index 3c316dfee2..126edf6281 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -881,6 +881,7 @@ class SpeechT5ForTextToSpeechTester: @require_torch class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (SpeechT5ForTextToSpeech,) if is_torch_available() else () + all_generative_model_classes = () is_encoder_decoder = True test_pruning = False test_headmasking = False @@ -892,6 +893,12 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase): def test_config(self): self.config_tester.run_common_tests() + def test_model_can_generate(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + model = model_class(config) + self.assertTrue(model.can_generate()) + def test_save_load_strict(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs() for model_class in self.all_model_classes: