[core] remove GenerationMixin inheritance by default in PreTrainedModel (#37173)

This commit is contained in:
Joao Gante
2025-04-08 16:42:05 +01:00
committed by GitHub
parent aab0878327
commit 4321b0648c
10 changed files with 54 additions and 83 deletions

View File

@@ -1430,27 +1430,6 @@ class GenerationMixin:
return transition_scores return transition_scores
def _validate_model_class(self):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
# TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from
# `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can
# safely call `GenerationMixin.generate`
if not self.can_generate():
terminations_with_generation_support = [
"ForCausalLM",
"ForConditionalGeneration",
"ForSpeechSeq2Seq",
"ForVision2Seq",
]
raise TypeError(
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
"it doesn't have a language model head. Classes that support generation often end in one of these "
f"names: {terminations_with_generation_support}."
)
def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer): def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer):
if assistant_model is None: if assistant_model is None:
return return
@@ -2213,7 +2192,6 @@ class GenerationMixin:
""" """
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation

View File

@@ -55,7 +55,7 @@ if is_torchao_available():
from .activations import get_activation from .activations import get_activation
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig, GenerationMixin from .generation import CompileConfig, GenerationConfig
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.accelerate import find_tied_parameters, init_empty_weights from .integrations.accelerate import find_tied_parameters, init_empty_weights
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
@@ -1704,8 +1704,7 @@ class ModuleUtilsMixin:
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
# TODO (joao): remove `GenerationMixin` inheritance in v4.50 class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
r""" r"""
Base class for all models. Base class for all models.
@@ -2157,12 +2156,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
continue continue
if "PreTrainedModel" not in str(base) and base.can_generate(): if "PreTrainedModel" not in str(base) and base.can_generate():
return True return True
# BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this # Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
# was how we detected whether a model could generate. # was how we detected whether a model could generate.
if "GenerationMixin" not in str(cls.prepare_inputs_for_generation): if hasattr(cls, "prepare_inputs_for_generation"): # implicit: doesn't inherit `GenerationMixin`
logger.warning_once( logger.warning(
f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly " f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
"overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, " "defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
"`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability " "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
"to call `generate` and other related functions." "to call `generate` and other related functions."
"\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the " "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
@@ -2172,7 +2171,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"\n - If you are not the owner of the model architecture class, please contact the model code owner " "\n - If you are not the owner of the model architecture class, please contact the model code owner "
"to update it." "to update it."
) )
return True
# Otherwise, can't generate # Otherwise, can't generate
return False return False

View File

@@ -730,8 +730,12 @@ def add_generation_mixin_to_remote_model(model_class):
# 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
# `prepare_inputs_for_generation` method. # `prepare_inputs_for_generation` method.
has_custom_generate = "GenerationMixin" not in str(getattr(model_class, "generate")) has_custom_generate = hasattr(model_class, "generate") and "GenerationMixin" not in str(
has_custom_prepare_inputs = "GenerationMixin" not in str(getattr(model_class, "prepare_inputs_for_generation")) getattr(model_class, "generate")
)
has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str(
getattr(model_class, "prepare_inputs_for_generation")
)
if has_custom_generate or has_custom_prepare_inputs: if has_custom_generate or has_custom_prepare_inputs:
model_class_with_generation_mixin = type( model_class_with_generation_mixin = type(
model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__} model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}

View File

@@ -1512,8 +1512,8 @@ class BertForMaskedLM(BertPreTrainedModel):
@classmethod @classmethod
def can_generate(cls) -> bool: def can_generate(cls) -> bool:
""" """
Legacy correction: BertForMaskedLM can't call `generate()` from GenerationMixin. Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
Remove after v4.50, when we stop making `PreTrainedModel` inherit from `GenerationMixin`. `prepare_inputs_for_generation` method.
""" """
return False return False

View File

@@ -1328,8 +1328,8 @@ class ErnieForMaskedLM(ErniePreTrainedModel):
@classmethod @classmethod
def can_generate(cls) -> bool: def can_generate(cls) -> bool:
""" """
Legacy correction: ErnieForMaskedLM can't call `generate()` from GenerationMixin. Legacy correction: ErnieForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
Remove after v4.50, when we stop making `PreTrainedModel` inherit from `GenerationMixin`. `prepare_inputs_for_generation` method.
""" """
return False return False

View File

@@ -22,7 +22,7 @@ import torch
from torch import nn from torch import nn
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList from ...generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList
from ...modeling_outputs import ModelOutput from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
@@ -1122,7 +1122,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
""", """,
RAG_START_DOCSTRING, RAG_START_DOCSTRING,
) )
class RagTokenForGeneration(RagPreTrainedModel): class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
def __init__( def __init__(
self, self,
config: Optional[PretrainedConfig] = None, config: Optional[PretrainedConfig] = None,

View File

@@ -999,6 +999,14 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
return {"input_ids": input_ids, "attention_mask": attention_mask} return {"input_ids": input_ids, "attention_mask": attention_mask}
@classmethod
def can_generate(cls) -> bool:
"""
Legacy correction: RemBertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
`prepare_inputs_for_generation` method.
"""
return False
@add_start_docstrings( @add_start_docstrings(
"""RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING """RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING

View File

@@ -24,6 +24,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
@@ -2242,7 +2243,7 @@ class SpeechT5Model(SpeechT5PreTrainedModel):
"""SpeechT5 Model with a speech encoder and a text decoder.""", """SpeechT5 Model with a speech encoder and a text decoder.""",
SPEECHT5_START_DOCSTRING, SPEECHT5_START_DOCSTRING,
) )
class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel): class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["text_decoder_postnet.lm_head.weight"] _tied_weights_keys = ["text_decoder_postnet.lm_head.weight"]
def __init__(self, config: SpeechT5Config): def __init__(self, config: SpeechT5Config):
@@ -2413,44 +2414,6 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
encoder_attentions=outputs.encoder_attentions, encoder_attentions=outputs.encoder_attentions,
) )
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# Note that this model doesn't inherit from the generation mixin, has unique generate function
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
@staticmethod @staticmethod
def _reorder_cache(past_key_values, beam_idx): def _reorder_cache(past_key_values, beam_idx):
reordered_past = () reordered_past = ()

View File

@@ -31,6 +31,7 @@ from transformers.testing_utils import (
from transformers.trainer_utils import set_seed from transformers.trainer_utils import set_seed
from transformers.utils import cached_property from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ( from ...test_modeling_common import (
ModelTesterMixin, ModelTesterMixin,
@@ -314,6 +315,15 @@ class SpeechT5ForSpeechToTextTester:
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
) )
def get_subsampled_output_lengths(self, input_lengths):
"""
Computes the output length of the convolutional layers
"""
for stride in self.conv_stride:
input_lengths = (input_lengths // stride) - 1
return input_lengths
def create_and_check_model_forward(self, config, inputs_dict): def create_and_check_model_forward(self, config, inputs_dict):
model = SpeechT5ForSpeechToText(config=config).to(torch_device).eval() model = SpeechT5ForSpeechToText(config=config).to(torch_device).eval()
@@ -359,10 +369,8 @@ class SpeechT5ForSpeechToTextTester:
@require_torch @require_torch
class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase): class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase, GenerationTesterMixin):
all_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else () all_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else ()
# Doesn't run generation tests. TODO eustache/joao: shape checks probably need an update
all_generative_model_classes = ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_headmasking = False test_headmasking = False
@@ -727,6 +735,18 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
module.masked_spec_embed.data.fill_(3) module.masked_spec_embed.data.fill_(3)
@unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test
def test_generate_with_head_masking(self):
pass
@unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="Very flaky") # TODO (joao, eustache): have a look at this test
def test_generate_continue_from_past_key_values(self):
pass
@require_torch @require_torch
@require_sentencepiece @require_sentencepiece

View File

@@ -1720,8 +1720,8 @@ class ModelUtilsTest(TestCasePlus):
self.assertTrue("" == cl.out) self.assertTrue("" == cl.out)
self.assertTrue(can_generate) self.assertTrue(can_generate)
# 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited # 4 - Legacy: models with a custom `prepare_inputs_for_generation` can generate (it was assumed
# `GenerationMixin`) # they inherited `GenerationMixin`). Deprecated in v4.45 and removed in v4.51.
class DummyBertWithPrepareInputs(BertModel): class DummyBertWithPrepareInputs(BertModel):
def prepare_inputs_for_generation(self): def prepare_inputs_for_generation(self):
pass pass
@@ -1729,7 +1729,7 @@ class ModelUtilsTest(TestCasePlus):
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
can_generate = DummyBertWithPrepareInputs.can_generate() can_generate = DummyBertWithPrepareInputs.can_generate()
self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out) self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
self.assertTrue(can_generate) self.assertFalse(can_generate)
def test_save_and_load_config_with_custom_generation(self): def test_save_and_load_config_with_custom_generation(self):
""" """