[core] remove GenerationMixin inheritance by default in PreTrainedModel (#37173)

This commit is contained in:
Joao Gante
2025-04-08 16:42:05 +01:00
committed by GitHub
parent aab0878327
commit 4321b0648c
10 changed files with 54 additions and 83 deletions

View File

@@ -1430,27 +1430,6 @@ class GenerationMixin:
return transition_scores
def _validate_model_class(self):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
# TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from
# `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can
# safely call `GenerationMixin.generate`
if not self.can_generate():
terminations_with_generation_support = [
"ForCausalLM",
"ForConditionalGeneration",
"ForSpeechSeq2Seq",
"ForVision2Seq",
]
raise TypeError(
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
"it doesn't have a language model head. Classes that support generation often end in one of these "
f"names: {terminations_with_generation_support}."
)
def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer):
if assistant_model is None:
return
@@ -2213,7 +2192,6 @@ class GenerationMixin:
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation

View File

@@ -55,7 +55,7 @@ if is_torchao_available():
from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .generation import CompileConfig, GenerationConfig
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.accelerate import find_tied_parameters, init_empty_weights
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
@@ -1704,8 +1704,7 @@ class ModuleUtilsMixin:
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
# TODO (joao): remove `GenerationMixin` inheritance in v4.50
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
r"""
Base class for all models.
@@ -2157,12 +2156,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
continue
if "PreTrainedModel" not in str(base) and base.can_generate():
return True
# BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
# Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
# was how we detected whether a model could generate.
if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
logger.warning_once(
if hasattr(cls, "prepare_inputs_for_generation"): # implicit: doesn't inherit `GenerationMixin`
logger.warning(
f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
"overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
"defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
"`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
"to call `generate` and other related functions."
"\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
@@ -2172,7 +2171,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"\n - If you are not the owner of the model architecture class, please contact the model code owner "
"to update it."
)
return True
# Otherwise, can't generate
return False

View File

@@ -730,8 +730,12 @@ def add_generation_mixin_to_remote_model(model_class):
# 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
# `prepare_inputs_for_generation` method.
has_custom_generate = "GenerationMixin" not in str(getattr(model_class, "generate"))
has_custom_prepare_inputs = "GenerationMixin" not in str(getattr(model_class, "prepare_inputs_for_generation"))
has_custom_generate = hasattr(model_class, "generate") and "GenerationMixin" not in str(
getattr(model_class, "generate")
)
has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str(
getattr(model_class, "prepare_inputs_for_generation")
)
if has_custom_generate or has_custom_prepare_inputs:
model_class_with_generation_mixin = type(
model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}

View File

@@ -1512,8 +1512,8 @@ class BertForMaskedLM(BertPreTrainedModel):
@classmethod
def can_generate(cls) -> bool:
"""
Legacy correction: BertForMaskedLM can't call `generate()` from GenerationMixin.
Remove after v4.50, when we stop making `PreTrainedModel` inherit from `GenerationMixin`.
Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
`prepare_inputs_for_generation` method.
"""
return False

View File

@@ -1328,8 +1328,8 @@ class ErnieForMaskedLM(ErniePreTrainedModel):
@classmethod
def can_generate(cls) -> bool:
"""
Legacy correction: ErnieForMaskedLM can't call `generate()` from GenerationMixin.
Remove after v4.50, when we stop making `PreTrainedModel` inherit from `GenerationMixin`.
Legacy correction: ErnieForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
`prepare_inputs_for_generation` method.
"""
return False

View File

@@ -22,7 +22,7 @@ import torch
from torch import nn
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ...generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
@@ -1122,7 +1122,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
""",
RAG_START_DOCSTRING,
)
class RagTokenForGeneration(RagPreTrainedModel):
class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
def __init__(
self,
config: Optional[PretrainedConfig] = None,

View File

@@ -999,6 +999,14 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
return {"input_ids": input_ids, "attention_mask": attention_mask}
@classmethod
def can_generate(cls) -> bool:
"""
Legacy correction: RemBertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
`prepare_inputs_for_generation` method.
"""
return False
@add_start_docstrings(
"""RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING

View File

@@ -24,6 +24,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
@@ -2242,7 +2243,7 @@ class SpeechT5Model(SpeechT5PreTrainedModel):
"""SpeechT5 Model with a speech encoder and a text decoder.""",
SPEECHT5_START_DOCSTRING,
)
class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["text_decoder_postnet.lm_head.weight"]
def __init__(self, config: SpeechT5Config):
@@ -2413,44 +2414,6 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
encoder_attentions=outputs.encoder_attentions,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# Note that this model doesn't inherit from the generation mixin, has unique generate function
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()

View File

@@ -31,6 +31,7 @@ from transformers.testing_utils import (
from transformers.trainer_utils import set_seed
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
ModelTesterMixin,
@@ -314,6 +315,15 @@ class SpeechT5ForSpeechToTextTester:
vocab_size=self.vocab_size,
)
def get_subsampled_output_lengths(self, input_lengths):
"""
Computes the output length of the convolutional layers
"""
for stride in self.conv_stride:
input_lengths = (input_lengths // stride) - 1
return input_lengths
def create_and_check_model_forward(self, config, inputs_dict):
model = SpeechT5ForSpeechToText(config=config).to(torch_device).eval()
@@ -359,10 +369,8 @@ class SpeechT5ForSpeechToTextTester:
@require_torch
class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase, GenerationTesterMixin):
all_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else ()
# Doesn't run generation tests. TODO eustache/joao: shape checks probably need an update
all_generative_model_classes = ()
is_encoder_decoder = True
test_pruning = False
test_headmasking = False
@@ -727,6 +735,18 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
module.masked_spec_embed.data.fill_(3)
@unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test
def test_generate_with_head_masking(self):
pass
@unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="Very flaky") # TODO (joao, eustache): have a look at this test
def test_generate_continue_from_past_key_values(self):
pass
@require_torch
@require_sentencepiece

View File

@@ -1720,8 +1720,8 @@ class ModelUtilsTest(TestCasePlus):
self.assertTrue("" == cl.out)
self.assertTrue(can_generate)
# 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
# `GenerationMixin`)
# 4 - Legacy: models with a custom `prepare_inputs_for_generation` can generate (it was assumed
# they inherited `GenerationMixin`). Deprecated in v4.45 and removed in v4.51.
class DummyBertWithPrepareInputs(BertModel):
def prepare_inputs_for_generation(self):
pass
@@ -1729,7 +1729,7 @@ class ModelUtilsTest(TestCasePlus):
with CaptureLogger(logger) as cl:
can_generate = DummyBertWithPrepareInputs.can_generate()
self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
self.assertTrue(can_generate)
self.assertFalse(can_generate)
def test_save_and_load_config_with_custom_generation(self):
"""