[core] remove GenerationMixin inheritance by default in PreTrainedModel (#37173)
This commit is contained in:
@@ -1430,27 +1430,6 @@ class GenerationMixin:
|
||||
|
||||
return transition_scores
|
||||
|
||||
def _validate_model_class(self):
|
||||
"""
|
||||
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
|
||||
right class to use.
|
||||
"""
|
||||
# TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from
|
||||
# `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can
|
||||
# safely call `GenerationMixin.generate`
|
||||
if not self.can_generate():
|
||||
terminations_with_generation_support = [
|
||||
"ForCausalLM",
|
||||
"ForConditionalGeneration",
|
||||
"ForSpeechSeq2Seq",
|
||||
"ForVision2Seq",
|
||||
]
|
||||
raise TypeError(
|
||||
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
|
||||
"it doesn't have a language model head. Classes that support generation often end in one of these "
|
||||
f"names: {terminations_with_generation_support}."
|
||||
)
|
||||
|
||||
def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer):
|
||||
if assistant_model is None:
|
||||
return
|
||||
@@ -2213,7 +2192,6 @@ class GenerationMixin:
|
||||
"""
|
||||
|
||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||
self._validate_model_class()
|
||||
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
||||
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ if is_torchao_available():
|
||||
from .activations import get_activation
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .generation import CompileConfig, GenerationConfig, GenerationMixin
|
||||
from .generation import CompileConfig, GenerationConfig
|
||||
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
||||
from .integrations.accelerate import find_tied_parameters, init_empty_weights
|
||||
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
|
||||
@@ -1704,8 +1704,7 @@ class ModuleUtilsMixin:
|
||||
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
|
||||
|
||||
|
||||
# TODO (joao): remove `GenerationMixin` inheritance in v4.50
|
||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
|
||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
Base class for all models.
|
||||
|
||||
@@ -2157,12 +2156,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
continue
|
||||
if "PreTrainedModel" not in str(base) and base.can_generate():
|
||||
return True
|
||||
# BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
|
||||
# Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
|
||||
# was how we detected whether a model could generate.
|
||||
if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
|
||||
logger.warning_once(
|
||||
if hasattr(cls, "prepare_inputs_for_generation"): # implicit: doesn't inherit `GenerationMixin`
|
||||
logger.warning(
|
||||
f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
|
||||
"overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
|
||||
"defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
|
||||
"`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
|
||||
"to call `generate` and other related functions."
|
||||
"\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
|
||||
@@ -2172,7 +2171,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
"\n - If you are not the owner of the model architecture class, please contact the model code owner "
|
||||
"to update it."
|
||||
)
|
||||
return True
|
||||
# Otherwise, can't generate
|
||||
return False
|
||||
|
||||
|
||||
@@ -730,8 +730,12 @@ def add_generation_mixin_to_remote_model(model_class):
|
||||
|
||||
# 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
|
||||
# `prepare_inputs_for_generation` method.
|
||||
has_custom_generate = "GenerationMixin" not in str(getattr(model_class, "generate"))
|
||||
has_custom_prepare_inputs = "GenerationMixin" not in str(getattr(model_class, "prepare_inputs_for_generation"))
|
||||
has_custom_generate = hasattr(model_class, "generate") and "GenerationMixin" not in str(
|
||||
getattr(model_class, "generate")
|
||||
)
|
||||
has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str(
|
||||
getattr(model_class, "prepare_inputs_for_generation")
|
||||
)
|
||||
if has_custom_generate or has_custom_prepare_inputs:
|
||||
model_class_with_generation_mixin = type(
|
||||
model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
|
||||
|
||||
@@ -1512,8 +1512,8 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
@classmethod
|
||||
def can_generate(cls) -> bool:
|
||||
"""
|
||||
Legacy correction: BertForMaskedLM can't call `generate()` from GenerationMixin.
|
||||
Remove after v4.50, when we stop making `PreTrainedModel` inherit from `GenerationMixin`.
|
||||
Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
|
||||
`prepare_inputs_for_generation` method.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
@@ -1328,8 +1328,8 @@ class ErnieForMaskedLM(ErniePreTrainedModel):
|
||||
@classmethod
|
||||
def can_generate(cls) -> bool:
|
||||
"""
|
||||
Legacy correction: ErnieForMaskedLM can't call `generate()` from GenerationMixin.
|
||||
Remove after v4.50, when we stop making `PreTrainedModel` inherit from `GenerationMixin`.
|
||||
Legacy correction: ErnieForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
|
||||
`prepare_inputs_for_generation` method.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
||||
from ...generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList
|
||||
from ...modeling_outputs import ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
@@ -1122,7 +1122,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
||||
""",
|
||||
RAG_START_DOCSTRING,
|
||||
)
|
||||
class RagTokenForGeneration(RagPreTrainedModel):
|
||||
class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[PretrainedConfig] = None,
|
||||
|
||||
@@ -999,6 +999,14 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
|
||||
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
|
||||
@classmethod
|
||||
def can_generate(cls) -> bool:
|
||||
"""
|
||||
Legacy correction: RemBertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
|
||||
`prepare_inputs_for_generation` method.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING
|
||||
|
||||
@@ -24,6 +24,7 @@ from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
||||
@@ -2242,7 +2243,7 @@ class SpeechT5Model(SpeechT5PreTrainedModel):
|
||||
"""SpeechT5 Model with a speech encoder and a text decoder.""",
|
||||
SPEECHT5_START_DOCSTRING,
|
||||
)
|
||||
class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
|
||||
class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["text_decoder_postnet.lm_head.weight"]
|
||||
|
||||
def __init__(self, config: SpeechT5Config):
|
||||
@@ -2413,44 +2414,6 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
use_cache=None,
|
||||
encoder_outputs=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Note that this model doesn't inherit from the generation mixin, has unique generate function
|
||||
|
||||
# cut decoder_input_ids if past is used
|
||||
if past_key_values is not None:
|
||||
past_length = past_key_values[0][0].shape[2]
|
||||
|
||||
# Some generation methods already pass only the last input ID
|
||||
if decoder_input_ids.shape[1] > past_length:
|
||||
remove_prefix_length = past_length
|
||||
else:
|
||||
# Default to old behavior: keep only final ID
|
||||
remove_prefix_length = decoder_input_ids.shape[1] - 1
|
||||
|
||||
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
|
||||
|
||||
return {
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past_key_values,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
|
||||
@@ -31,6 +31,7 @@ from transformers.testing_utils import (
|
||||
from transformers.trainer_utils import set_seed
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
@@ -314,6 +315,15 @@ class SpeechT5ForSpeechToTextTester:
|
||||
vocab_size=self.vocab_size,
|
||||
)
|
||||
|
||||
def get_subsampled_output_lengths(self, input_lengths):
|
||||
"""
|
||||
Computes the output length of the convolutional layers
|
||||
"""
|
||||
for stride in self.conv_stride:
|
||||
input_lengths = (input_lengths // stride) - 1
|
||||
|
||||
return input_lengths
|
||||
|
||||
def create_and_check_model_forward(self, config, inputs_dict):
|
||||
model = SpeechT5ForSpeechToText(config=config).to(torch_device).eval()
|
||||
|
||||
@@ -359,10 +369,8 @@ class SpeechT5ForSpeechToTextTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
|
||||
class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase, GenerationTesterMixin):
|
||||
all_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else ()
|
||||
# Doesn't run generation tests. TODO eustache/joao: shape checks probably need an update
|
||||
all_generative_model_classes = ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@@ -727,6 +735,18 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
|
||||
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
|
||||
module.masked_spec_embed.data.fill_(3)
|
||||
|
||||
@unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test
|
||||
def test_generate_with_head_masking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Very flaky") # TODO (joao, eustache): have a look at this test
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
|
||||
@@ -1720,8 +1720,8 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertTrue("" == cl.out)
|
||||
self.assertTrue(can_generate)
|
||||
|
||||
# 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
|
||||
# `GenerationMixin`)
|
||||
# 4 - Legacy: models with a custom `prepare_inputs_for_generation` can generate (it was assumed
|
||||
# they inherited `GenerationMixin`). Deprecated in v4.45 and removed in v4.51.
|
||||
class DummyBertWithPrepareInputs(BertModel):
|
||||
def prepare_inputs_for_generation(self):
|
||||
pass
|
||||
@@ -1729,7 +1729,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
with CaptureLogger(logger) as cl:
|
||||
can_generate = DummyBertWithPrepareInputs.can_generate()
|
||||
self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
|
||||
self.assertTrue(can_generate)
|
||||
self.assertFalse(can_generate)
|
||||
|
||||
def test_save_and_load_config_with_custom_generation(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user