From 4321b0648c4d80f8d0e3b02d555521c7e1483f46 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 8 Apr 2025 16:42:05 +0100 Subject: [PATCH] [core] remove `GenerationMixin` inheritance by default in `PreTrainedModel` (#37173) --- src/transformers/generation/utils.py | 22 ---------- src/transformers/modeling_utils.py | 14 +++---- src/transformers/models/auto/auto_factory.py | 8 +++- src/transformers/models/bert/modeling_bert.py | 4 +- .../models/ernie/modeling_ernie.py | 4 +- src/transformers/models/rag/modeling_rag.py | 4 +- .../models/rembert/modeling_rembert.py | 8 ++++ .../models/speecht5/modeling_speecht5.py | 41 +------------------ .../models/speecht5/test_modeling_speecht5.py | 26 ++++++++++-- tests/utils/test_modeling_utils.py | 6 +-- 10 files changed, 54 insertions(+), 83 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b68de89f66..09cbc9c446 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1430,27 +1430,6 @@ class GenerationMixin: return transition_scores - def _validate_model_class(self): - """ - Confirms that the model class is compatible with generation. If not, raises an exception that points to the - right class to use. - """ - # TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from - # `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can - # safely call `GenerationMixin.generate` - if not self.can_generate(): - terminations_with_generation_support = [ - "ForCausalLM", - "ForConditionalGeneration", - "ForSpeechSeq2Seq", - "ForVision2Seq", - ] - raise TypeError( - f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as " - "it doesn't have a language model head. Classes that support generation often end in one of these " - f"names: {terminations_with_generation_support}." - ) - def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer): if assistant_model is None: return @@ -2213,7 +2192,6 @@ class GenerationMixin: """ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call - self._validate_model_class() tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index cac9ba1eea..b8a7831176 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -55,7 +55,7 @@ if is_torchao_available(): from .activations import get_activation from .configuration_utils import PretrainedConfig from .dynamic_module_utils import custom_object_save -from .generation import CompileConfig, GenerationConfig, GenerationMixin +from .generation import CompileConfig, GenerationConfig from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled from .integrations.accelerate import find_tied_parameters, init_empty_weights from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available @@ -1704,8 +1704,7 @@ class ModuleUtilsMixin: return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) -# TODO (joao): remove `GenerationMixin` inheritance in v4.50 -class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): +class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin): r""" Base class for all models. @@ -2157,12 +2156,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix continue if "PreTrainedModel" not in str(base) and base.can_generate(): return True - # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this + # Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this # was how we detected whether a model could generate. - if "GenerationMixin" not in str(cls.prepare_inputs_for_generation): - logger.warning_once( + if hasattr(cls, "prepare_inputs_for_generation"): # implicit: doesn't inherit `GenerationMixin` + logger.warning( f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly " - "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, " + "defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, " "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability " "to call `generate` and other related functions." "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the " @@ -2172,7 +2171,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix "\n - If you are not the owner of the model architecture class, please contact the model code owner " "to update it." ) - return True # Otherwise, can't generate return False diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index d21a16beb1..9c5690b496 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -730,8 +730,12 @@ def add_generation_mixin_to_remote_model(model_class): # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or # `prepare_inputs_for_generation` method. - has_custom_generate = "GenerationMixin" not in str(getattr(model_class, "generate")) - has_custom_prepare_inputs = "GenerationMixin" not in str(getattr(model_class, "prepare_inputs_for_generation")) + has_custom_generate = hasattr(model_class, "generate") and "GenerationMixin" not in str( + getattr(model_class, "generate") + ) + has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str( + getattr(model_class, "prepare_inputs_for_generation") + ) if has_custom_generate or has_custom_prepare_inputs: model_class_with_generation_mixin = type( model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__} diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 38d2c8c8b5..3a311f4a73 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -1512,8 +1512,8 @@ class BertForMaskedLM(BertPreTrainedModel): @classmethod def can_generate(cls) -> bool: """ - Legacy correction: BertForMaskedLM can't call `generate()` from GenerationMixin. - Remove after v4.50, when we stop making `PreTrainedModel` inherit from `GenerationMixin`. + Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a + `prepare_inputs_for_generation` method. """ return False diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 559078b1ef..221dd57485 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -1328,8 +1328,8 @@ class ErnieForMaskedLM(ErniePreTrainedModel): @classmethod def can_generate(cls) -> bool: """ - Legacy correction: ErnieForMaskedLM can't call `generate()` from GenerationMixin. - Remove after v4.50, when we stop making `PreTrainedModel` inherit from `GenerationMixin`. + Legacy correction: ErnieForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a + `prepare_inputs_for_generation` method. """ return False diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index c2258d9767..822347950a 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -22,7 +22,7 @@ import torch from torch import nn from ...configuration_utils import PretrainedConfig -from ...generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList +from ...generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings @@ -1122,7 +1122,7 @@ class RagSequenceForGeneration(RagPreTrainedModel): """, RAG_START_DOCSTRING, ) -class RagTokenForGeneration(RagPreTrainedModel): +class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin): def __init__( self, config: Optional[PretrainedConfig] = None, diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index a8942b1e7c..ef3250a712 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -999,6 +999,14 @@ class RemBertForMaskedLM(RemBertPreTrainedModel): return {"input_ids": input_ids, "attention_mask": attention_mask} + @classmethod + def can_generate(cls) -> bool: + """ + Legacy correction: RemBertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a + `prepare_inputs_for_generation` method. + """ + return False + @add_start_docstrings( """RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index d85e52924a..a46ab875d4 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask @@ -2242,7 +2243,7 @@ class SpeechT5Model(SpeechT5PreTrainedModel): """SpeechT5 Model with a speech encoder and a text decoder.""", SPEECHT5_START_DOCSTRING, ) -class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel): +class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin): _tied_weights_keys = ["text_decoder_postnet.lm_head.weight"] def __init__(self, config: SpeechT5Config): @@ -2413,44 +2414,6 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel): encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # Note that this model doesn't inherit from the generation mixin, has unique generate function - - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index dcff134ef8..3655a188d8 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -31,6 +31,7 @@ from transformers.testing_utils import ( from transformers.trainer_utils import set_seed from transformers.utils import cached_property +from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( ModelTesterMixin, @@ -314,6 +315,15 @@ class SpeechT5ForSpeechToTextTester: vocab_size=self.vocab_size, ) + def get_subsampled_output_lengths(self, input_lengths): + """ + Computes the output length of the convolutional layers + """ + for stride in self.conv_stride: + input_lengths = (input_lengths // stride) - 1 + + return input_lengths + def create_and_check_model_forward(self, config, inputs_dict): model = SpeechT5ForSpeechToText(config=config).to(torch_device).eval() @@ -359,10 +369,8 @@ class SpeechT5ForSpeechToTextTester: @require_torch -class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase): +class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase, GenerationTesterMixin): all_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else () - # Doesn't run generation tests. TODO eustache/joao: shape checks probably need an update - all_generative_model_classes = () is_encoder_decoder = True test_pruning = False test_headmasking = False @@ -727,6 +735,18 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase): if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: module.masked_spec_embed.data.fill_(3) + @unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test + def test_generate_with_head_masking(self): + pass + + @unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test + def test_generate_without_input_ids(self): + pass + + @unittest.skip(reason="Very flaky") # TODO (joao, eustache): have a look at this test + def test_generate_continue_from_past_key_values(self): + pass + @require_torch @require_sentencepiece diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 0d2f7e8271..b75d118708 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1720,8 +1720,8 @@ class ModelUtilsTest(TestCasePlus): self.assertTrue("" == cl.out) self.assertTrue(can_generate) - # 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited - # `GenerationMixin`) + # 4 - Legacy: models with a custom `prepare_inputs_for_generation` can generate (it was assumed + # they inherited `GenerationMixin`). Deprecated in v4.45 and removed in v4.51. class DummyBertWithPrepareInputs(BertModel): def prepare_inputs_for_generation(self): pass @@ -1729,7 +1729,7 @@ class ModelUtilsTest(TestCasePlus): with CaptureLogger(logger) as cl: can_generate = DummyBertWithPrepareInputs.can_generate() self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out) - self.assertTrue(can_generate) + self.assertFalse(can_generate) def test_save_and_load_config_with_custom_generation(self): """