[core] remove GenerationMixin inheritance by default in PreTrainedModel (#37173)

This commit is contained in:
Joao Gante
2025-04-08 16:42:05 +01:00
committed by GitHub
parent aab0878327
commit 4321b0648c
10 changed files with 54 additions and 83 deletions

View File

@@ -31,6 +31,7 @@ from transformers.testing_utils import (
from transformers.trainer_utils import set_seed
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
ModelTesterMixin,
@@ -314,6 +315,15 @@ class SpeechT5ForSpeechToTextTester:
vocab_size=self.vocab_size,
)
def get_subsampled_output_lengths(self, input_lengths):
"""
Computes the output length of the convolutional layers
"""
for stride in self.conv_stride:
input_lengths = (input_lengths // stride) - 1
return input_lengths
def create_and_check_model_forward(self, config, inputs_dict):
model = SpeechT5ForSpeechToText(config=config).to(torch_device).eval()
@@ -359,10 +369,8 @@ class SpeechT5ForSpeechToTextTester:
@require_torch
class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase, GenerationTesterMixin):
all_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else ()
# Doesn't run generation tests. TODO eustache/joao: shape checks probably need an update
all_generative_model_classes = ()
is_encoder_decoder = True
test_pruning = False
test_headmasking = False
@@ -727,6 +735,18 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
module.masked_spec_embed.data.fill_(3)
@unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test
def test_generate_with_head_masking(self):
pass
@unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="Very flaky") # TODO (joao, eustache): have a look at this test
def test_generate_continue_from_past_key_values(self):
pass
@require_torch
@require_sentencepiece

View File

@@ -1720,8 +1720,8 @@ class ModelUtilsTest(TestCasePlus):
self.assertTrue("" == cl.out)
self.assertTrue(can_generate)
# 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
# `GenerationMixin`)
# 4 - Legacy: models with a custom `prepare_inputs_for_generation` can generate (it was assumed
# they inherited `GenerationMixin`). Deprecated in v4.45 and removed in v4.51.
class DummyBertWithPrepareInputs(BertModel):
def prepare_inputs_for_generation(self):
pass
@@ -1729,7 +1729,7 @@ class ModelUtilsTest(TestCasePlus):
with CaptureLogger(logger) as cl:
can_generate = DummyBertWithPrepareInputs.can_generate()
self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
self.assertTrue(can_generate)
self.assertFalse(can_generate)
def test_save_and_load_config_with_custom_generation(self):
"""