[core] remove GenerationMixin inheritance by default in PreTrainedModel (#37173)
This commit is contained in:
@@ -1430,27 +1430,6 @@ class GenerationMixin:
|
|||||||
|
|
||||||
return transition_scores
|
return transition_scores
|
||||||
|
|
||||||
def _validate_model_class(self):
|
|
||||||
"""
|
|
||||||
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
|
|
||||||
right class to use.
|
|
||||||
"""
|
|
||||||
# TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from
|
|
||||||
# `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can
|
|
||||||
# safely call `GenerationMixin.generate`
|
|
||||||
if not self.can_generate():
|
|
||||||
terminations_with_generation_support = [
|
|
||||||
"ForCausalLM",
|
|
||||||
"ForConditionalGeneration",
|
|
||||||
"ForSpeechSeq2Seq",
|
|
||||||
"ForVision2Seq",
|
|
||||||
]
|
|
||||||
raise TypeError(
|
|
||||||
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
|
|
||||||
"it doesn't have a language model head. Classes that support generation often end in one of these "
|
|
||||||
f"names: {terminations_with_generation_support}."
|
|
||||||
)
|
|
||||||
|
|
||||||
def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer):
|
def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer):
|
||||||
if assistant_model is None:
|
if assistant_model is None:
|
||||||
return
|
return
|
||||||
@@ -2213,7 +2192,6 @@ class GenerationMixin:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||||
self._validate_model_class()
|
|
||||||
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
||||||
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
|
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ if is_torchao_available():
|
|||||||
from .activations import get_activation
|
from .activations import get_activation
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .generation import CompileConfig, GenerationConfig, GenerationMixin
|
from .generation import CompileConfig, GenerationConfig
|
||||||
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
||||||
from .integrations.accelerate import find_tied_parameters, init_empty_weights
|
from .integrations.accelerate import find_tied_parameters, init_empty_weights
|
||||||
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
|
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
|
||||||
@@ -1704,8 +1704,7 @@ class ModuleUtilsMixin:
|
|||||||
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
|
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
|
||||||
|
|
||||||
|
|
||||||
# TODO (joao): remove `GenerationMixin` inheritance in v4.50
|
class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
|
||||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
|
|
||||||
r"""
|
r"""
|
||||||
Base class for all models.
|
Base class for all models.
|
||||||
|
|
||||||
@@ -2157,12 +2156,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
continue
|
continue
|
||||||
if "PreTrainedModel" not in str(base) and base.can_generate():
|
if "PreTrainedModel" not in str(base) and base.can_generate():
|
||||||
return True
|
return True
|
||||||
# BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
|
# Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
|
||||||
# was how we detected whether a model could generate.
|
# was how we detected whether a model could generate.
|
||||||
if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
|
if hasattr(cls, "prepare_inputs_for_generation"): # implicit: doesn't inherit `GenerationMixin`
|
||||||
logger.warning_once(
|
logger.warning(
|
||||||
f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
|
f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
|
||||||
"overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
|
"defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
|
||||||
"`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
|
"`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
|
||||||
"to call `generate` and other related functions."
|
"to call `generate` and other related functions."
|
||||||
"\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
|
"\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
|
||||||
@@ -2172,7 +2171,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
"\n - If you are not the owner of the model architecture class, please contact the model code owner "
|
"\n - If you are not the owner of the model architecture class, please contact the model code owner "
|
||||||
"to update it."
|
"to update it."
|
||||||
)
|
)
|
||||||
return True
|
|
||||||
# Otherwise, can't generate
|
# Otherwise, can't generate
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -730,8 +730,12 @@ def add_generation_mixin_to_remote_model(model_class):
|
|||||||
|
|
||||||
# 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
|
# 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
|
||||||
# `prepare_inputs_for_generation` method.
|
# `prepare_inputs_for_generation` method.
|
||||||
has_custom_generate = "GenerationMixin" not in str(getattr(model_class, "generate"))
|
has_custom_generate = hasattr(model_class, "generate") and "GenerationMixin" not in str(
|
||||||
has_custom_prepare_inputs = "GenerationMixin" not in str(getattr(model_class, "prepare_inputs_for_generation"))
|
getattr(model_class, "generate")
|
||||||
|
)
|
||||||
|
has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str(
|
||||||
|
getattr(model_class, "prepare_inputs_for_generation")
|
||||||
|
)
|
||||||
if has_custom_generate or has_custom_prepare_inputs:
|
if has_custom_generate or has_custom_prepare_inputs:
|
||||||
model_class_with_generation_mixin = type(
|
model_class_with_generation_mixin = type(
|
||||||
model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
|
model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
|
||||||
|
|||||||
@@ -1512,8 +1512,8 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def can_generate(cls) -> bool:
|
def can_generate(cls) -> bool:
|
||||||
"""
|
"""
|
||||||
Legacy correction: BertForMaskedLM can't call `generate()` from GenerationMixin.
|
Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
|
||||||
Remove after v4.50, when we stop making `PreTrainedModel` inherit from `GenerationMixin`.
|
`prepare_inputs_for_generation` method.
|
||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -1328,8 +1328,8 @@ class ErnieForMaskedLM(ErniePreTrainedModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def can_generate(cls) -> bool:
|
def can_generate(cls) -> bool:
|
||||||
"""
|
"""
|
||||||
Legacy correction: ErnieForMaskedLM can't call `generate()` from GenerationMixin.
|
Legacy correction: ErnieForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
|
||||||
Remove after v4.50, when we stop making `PreTrainedModel` inherit from `GenerationMixin`.
|
`prepare_inputs_for_generation` method.
|
||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
from ...generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList
|
||||||
from ...modeling_outputs import ModelOutput
|
from ...modeling_outputs import ModelOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||||
@@ -1122,7 +1122,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
|||||||
""",
|
""",
|
||||||
RAG_START_DOCSTRING,
|
RAG_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class RagTokenForGeneration(RagPreTrainedModel):
|
class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Optional[PretrainedConfig] = None,
|
config: Optional[PretrainedConfig] = None,
|
||||||
|
|||||||
@@ -999,6 +999,14 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
|
|||||||
|
|
||||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_generate(cls) -> bool:
|
||||||
|
"""
|
||||||
|
Legacy correction: RemBertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
|
||||||
|
`prepare_inputs_for_generation` method.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING
|
"""RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
|
from ...generation import GenerationMixin
|
||||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
from ...integrations.fsdp import is_fsdp_managed_module
|
from ...integrations.fsdp import is_fsdp_managed_module
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
||||||
@@ -2242,7 +2243,7 @@ class SpeechT5Model(SpeechT5PreTrainedModel):
|
|||||||
"""SpeechT5 Model with a speech encoder and a text decoder.""",
|
"""SpeechT5 Model with a speech encoder and a text decoder.""",
|
||||||
SPEECHT5_START_DOCSTRING,
|
SPEECHT5_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
|
class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["text_decoder_postnet.lm_head.weight"]
|
_tied_weights_keys = ["text_decoder_postnet.lm_head.weight"]
|
||||||
|
|
||||||
def __init__(self, config: SpeechT5Config):
|
def __init__(self, config: SpeechT5Config):
|
||||||
@@ -2413,44 +2414,6 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
|
|||||||
encoder_attentions=outputs.encoder_attentions,
|
encoder_attentions=outputs.encoder_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
|
||||||
self,
|
|
||||||
decoder_input_ids,
|
|
||||||
past_key_values=None,
|
|
||||||
attention_mask=None,
|
|
||||||
head_mask=None,
|
|
||||||
decoder_head_mask=None,
|
|
||||||
cross_attn_head_mask=None,
|
|
||||||
use_cache=None,
|
|
||||||
encoder_outputs=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# Note that this model doesn't inherit from the generation mixin, has unique generate function
|
|
||||||
|
|
||||||
# cut decoder_input_ids if past is used
|
|
||||||
if past_key_values is not None:
|
|
||||||
past_length = past_key_values[0][0].shape[2]
|
|
||||||
|
|
||||||
# Some generation methods already pass only the last input ID
|
|
||||||
if decoder_input_ids.shape[1] > past_length:
|
|
||||||
remove_prefix_length = past_length
|
|
||||||
else:
|
|
||||||
# Default to old behavior: keep only final ID
|
|
||||||
remove_prefix_length = decoder_input_ids.shape[1] - 1
|
|
||||||
|
|
||||||
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"encoder_outputs": encoder_outputs,
|
|
||||||
"past_key_values": past_key_values,
|
|
||||||
"decoder_input_ids": decoder_input_ids,
|
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"head_mask": head_mask,
|
|
||||||
"decoder_head_mask": decoder_head_mask,
|
|
||||||
"cross_attn_head_mask": cross_attn_head_mask,
|
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past_key_values, beam_idx):
|
def _reorder_cache(past_key_values, beam_idx):
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from transformers.testing_utils import (
|
|||||||
from transformers.trainer_utils import set_seed
|
from transformers.trainer_utils import set_seed
|
||||||
from transformers.utils import cached_property
|
from transformers.utils import cached_property
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import (
|
from ...test_modeling_common import (
|
||||||
ModelTesterMixin,
|
ModelTesterMixin,
|
||||||
@@ -314,6 +315,15 @@ class SpeechT5ForSpeechToTextTester:
|
|||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_subsampled_output_lengths(self, input_lengths):
|
||||||
|
"""
|
||||||
|
Computes the output length of the convolutional layers
|
||||||
|
"""
|
||||||
|
for stride in self.conv_stride:
|
||||||
|
input_lengths = (input_lengths // stride) - 1
|
||||||
|
|
||||||
|
return input_lengths
|
||||||
|
|
||||||
def create_and_check_model_forward(self, config, inputs_dict):
|
def create_and_check_model_forward(self, config, inputs_dict):
|
||||||
model = SpeechT5ForSpeechToText(config=config).to(torch_device).eval()
|
model = SpeechT5ForSpeechToText(config=config).to(torch_device).eval()
|
||||||
|
|
||||||
@@ -359,10 +369,8 @@ class SpeechT5ForSpeechToTextTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
|
class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase, GenerationTesterMixin):
|
||||||
all_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else ()
|
all_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else ()
|
||||||
# Doesn't run generation tests. TODO eustache/joao: shape checks probably need an update
|
|
||||||
all_generative_model_classes = ()
|
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
@@ -727,6 +735,18 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
|
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
|
||||||
module.masked_spec_embed.data.fill_(3)
|
module.masked_spec_embed.data.fill_(3)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test
|
||||||
|
def test_generate_with_head_masking(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Temporarily broken") # TODO (joao, eustache): have a look at this test
|
||||||
|
def test_generate_without_input_ids(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Very flaky") # TODO (joao, eustache): have a look at this test
|
||||||
|
def test_generate_continue_from_past_key_values(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
|
|||||||
@@ -1720,8 +1720,8 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
self.assertTrue("" == cl.out)
|
self.assertTrue("" == cl.out)
|
||||||
self.assertTrue(can_generate)
|
self.assertTrue(can_generate)
|
||||||
|
|
||||||
# 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
|
# 4 - Legacy: models with a custom `prepare_inputs_for_generation` can generate (it was assumed
|
||||||
# `GenerationMixin`)
|
# they inherited `GenerationMixin`). Deprecated in v4.45 and removed in v4.51.
|
||||||
class DummyBertWithPrepareInputs(BertModel):
|
class DummyBertWithPrepareInputs(BertModel):
|
||||||
def prepare_inputs_for_generation(self):
|
def prepare_inputs_for_generation(self):
|
||||||
pass
|
pass
|
||||||
@@ -1729,7 +1729,7 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
with CaptureLogger(logger) as cl:
|
with CaptureLogger(logger) as cl:
|
||||||
can_generate = DummyBertWithPrepareInputs.can_generate()
|
can_generate = DummyBertWithPrepareInputs.can_generate()
|
||||||
self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
|
self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
|
||||||
self.assertTrue(can_generate)
|
self.assertFalse(can_generate)
|
||||||
|
|
||||||
def test_save_and_load_config_with_custom_generation(self):
|
def test_save_and_load_config_with_custom_generation(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user