[core] remove GenerationMixin inheritance by default in PreTrainedModel (#37173)
This commit is contained in:
@@ -31,6 +31,7 @@ from transformers.testing_utils import (
|
||||
from transformers.trainer_utils import set_seed
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
@@ -314,6 +315,15 @@ class SpeechT5ForSpeechToTextTester:
|
||||
vocab_size=self.vocab_size,
|
||||
)
|
||||
|
||||
def get_subsampled_output_lengths(self, input_lengths):
|
||||
"""
|
||||
Computes the output length of the convolutional layers
|
||||
"""
|
||||
for stride in self.conv_stride:
|
||||
input_lengths = (input_lengths // stride) - 1
|
||||
|
||||
return input_lengths
|
||||
|
||||
def create_and_check_model_forward(self, config, inputs_dict):
|
||||
model = SpeechT5ForSpeechToText(config=config).to(torch_device).eval()
|
||||
|
||||
@@ -359,10 +369,8 @@ class SpeechT5ForSpeechToTextTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
|
||||
class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase, GenerationTesterMixin):
|
||||
all_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else ()
|
||||
# Doesn't run generation tests. TODO eustache/joao: shape checks probably need an update
|
||||
all_generative_model_classes = ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@@ -727,6 +735,18 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
|
||||
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
|
||||
module.masked_spec_embed.data.fill_(3)
|
||||
|
||||
@unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test
|
||||
def test_generate_with_head_masking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Very flaky") # TODO (joao, eustache): have a look at this test
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
|
||||
@@ -1720,8 +1720,8 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertTrue("" == cl.out)
|
||||
self.assertTrue(can_generate)
|
||||
|
||||
# 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
|
||||
# `GenerationMixin`)
|
||||
# 4 - Legacy: models with a custom `prepare_inputs_for_generation` can generate (it was assumed
|
||||
# they inherited `GenerationMixin`). Deprecated in v4.45 and removed in v4.51.
|
||||
class DummyBertWithPrepareInputs(BertModel):
|
||||
def prepare_inputs_for_generation(self):
|
||||
pass
|
||||
@@ -1729,7 +1729,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
with CaptureLogger(logger) as cl:
|
||||
can_generate = DummyBertWithPrepareInputs.can_generate()
|
||||
self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
|
||||
self.assertTrue(can_generate)
|
||||
self.assertFalse(can_generate)
|
||||
|
||||
def test_save_and_load_config_with_custom_generation(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user