Refactor embedding input/output getter/setter (#39339)

* simplify common get/set

* remove some noise

* change some 5 years old modeling utils

* update examples

* fix copies

* revert some changes

* fixes, gah

* format

* move to Mixin

* remove smolvlm specific require grad

* skip

* force defaults

* remodularise some stuff

* remodularise more stuff

* add safety for audio models

* style

* have a correct fallback, you daft donkey

* remove this argh

* change heuristic for audio models

* fixup

* revert

* this works

* revert again

* 🧠

* aaah ESM has two modelings aaah

* add informative but short comment

* add `input_embed_layer` mixin attribute

* style

* walrus has low precedence

* modular fix

* this was breaking parser
This commit is contained in:
Pablo Montalvo
2025-07-21 18:18:14 +02:00
committed by GitHub
parent 2da97f0943
commit 69b158260f
163 changed files with 235 additions and 2388 deletions

View File

@@ -1902,7 +1902,97 @@ class ModuleUtilsMixin:
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
class EmbeddingAccessMixin:
    """
    Base utilities to regroup getters and setters for embeddings.

    Introduces the `_input_embed_layer` class attribute, which names the attribute holding the input
    embeddings: it indicates where the input embeddings come from and where they should be set.
    """

    _input_embed_layer = "embed_tokens"  # default layer that holds input embeddings.

    def get_input_embeddings(self) -> nn.Module:
        """
        Returns the model's input embeddings.

        Order of attempts:
            1. `self.<_input_embed_layer>`, when it is set (i.e. not `None`)
            2. `self.model.<_input_embed_layer>`, then the legacy `self.model.embed_tokens`
            3. `self.embed_tokens`
            4. delegate to the *base model* registered under `base_model_prefix`, if one exists

        Returns:
            `nn.Module`: A torch module mapping vocabulary to hidden states.

        Raises:
            NotImplementedError: if none of the attempts above finds an embedding layer; subclasses
                with exotic layouts must override this method.
        """
        name = getattr(self, "_input_embed_layer", "embed_tokens")

        # 1) The configured attribute lives directly on this module ('embed_tokens' is the standard
        #    input embedding layer for most NLP models).
        if (default_embedding := getattr(self, name, None)) is not None:
            return default_embedding

        # 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`: the layer lives on
        #    `self.model`. The configured name is tried first so that this getter reads the layer that
        #    `set_input_embeddings` writes; `embed_tokens` is kept as the historical fallback.
        inner_model = getattr(self, "model", None)
        if inner_model is not None:
            for candidate in dict.fromkeys((name, "embed_tokens")):
                if hasattr(inner_model, candidate):
                    return getattr(inner_model, candidate)

        # 3) vanilla decoder-only architectures
        if hasattr(self, "embed_tokens"):
            return self.embed_tokens

        # 4) recurse once into the registered *base* model
        base_model = getattr(self, "base_model_prefix", None)
        if base_model is not None:
            base_model = getattr(self, base_model, None)
            if base_model is not None and base_model is not self:
                return base_model.get_input_embeddings()

        raise NotImplementedError(
            f"`get_input_embeddings` not autohandled for {self.__class__.__name__}; "
            "please override in the subclass."
        )

    def set_input_embeddings(self, value: nn.Module):
        """Fallback setter that handles **~70%** of models in the codebase.

        Order of attempts (`<name>` is `_input_embed_layer`, `embed_tokens` by default):
        1. `self.model.<name>`
        2. `self.<name>`
        3. delegate to the *base model* if one exists
        4. otherwise raise `NotImplementedError` so subclasses still can (and
           should) override for exotic layouts.

        Args:
            value (`nn.Module`): The module to install as input embeddings.
        """
        name = getattr(self, "_input_embed_layer", "embed_tokens")

        # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
        if hasattr(self, "model") and hasattr(self.model, name):
            setattr(self.model, name, value)
            return

        # 2) as well as vanilla decoder-only architectures
        if hasattr(self, name):
            setattr(self, name, value)
            return

        # 3) recurse once into the registered *base* model (e.g. for encoder/decoder).
        #    `base_model_prefix` is looked up defensively: the mixin itself does not define it, and an
        #    unset (`None`) base model must end in `NotImplementedError`, not in an `AttributeError`.
        prefix = getattr(self, "base_model_prefix", None)
        base_model = getattr(self, prefix, None) if prefix is not None else None
        if base_model is None or base_model is self:
            raise NotImplementedError(
                f"`set_input_embeddings` not autohandled for {self.__class__.__name__}; please override in the subclass."
            )
        base_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """
        Returns the model's output embeddings.

        Returns:
            `self.lm_head` when that attribute exists and `get_input_embeddings` can be resolved,
            `None` otherwise.
        """
        if not hasattr(self, "lm_head"):
            return None
        try:
            # Speech / vision backbones raise here, so we return None.
            # NOTE(review): the original author questioned whether probing `get_input_embeddings`
            # is a legitimate heuristic here.
            self.get_input_embeddings()
        except NotImplementedError:
            return None
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """
        Sets the model's output embedding, defaulting to setting new_embeddings to lm_head.

        Nothing happens when the model has no `lm_head` attribute or when it is `None`.
        """
        # `is not None` rather than truthiness: a missing attribute must not raise, and a module
        # that defines `__len__` (e.g. an empty container) must not be mistaken for "no head".
        if getattr(self, "lm_head", None) is not None:
            self.lm_head = new_embeddings
class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
r"""
Base class for all models.
@@ -2004,6 +2094,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
_supports_attention_backend = False
_can_record_outputs = None
# This attribute sets the default parameter to be
@property
@torch._dynamo.allow_in_graph
def can_record_outputs(self) -> dict[str, OutputRecorder]:
@@ -2267,6 +2359,101 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
return model
@classmethod
def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]:
"""
Checks that the requested attention implementation exists and tries to get the kernel from hub
if `attn_implementation` matches hf kernels pattern.
"""
if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation):
if not is_kernels_available():
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
# Extract repo_id and kernel_name from the string
repo_id, kernel_name = attn_implementation.split(":")
kernel_name = kernel_name.strip()
repo_id = repo_id.strip()
try:
kernel = get_kernel(repo_id)
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
except FileNotFoundError as e:
logger.warning(
f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead."
)
attn_implementation = None # try to dispatch SDPA and fallback eager if not available
except AttributeError:
raise ValueError(
"the kernel function name or class specified in the attn_implementation argument is not valid. \
Please check the documentation for the correct format, \
and check that the kernel exports the class and the function correctly."
)
if (
not isinstance(attn_implementation, dict)
and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys()
):
message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False):
message += (
', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
)
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
if cls._supports_flex_attn:
message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
raise ValueError(message + ".")
return attn_implementation
def set_attention_implementation(self, attn_implementation: Union[str, dict]):
    """
    Checks and dispatches to the requested attention implementation.

    The request is first validated with `_check_attn_implementation`, then propagated to every
    sub-config that has no attention implementation set yet, and finally resolved into
    `self.config._attn_implementation`.

    Args:
        attn_implementation (`str` or `dict`):
            The requested implementation. A dict maps sub-config names to implementations; its `""`
            key is read for this model's own config.
    """
    requested_attn_implementation = self._check_attn_implementation(attn_implementation)

    # Composite models consisting of several PretrainedModels can specify attention implementation as a dict where
    # keys are sub-config names. But most people will specify one `str` which means that should dispatch it for all sub-models.
    # See https://github.com/huggingface/transformers/pull/32238
    for key in self.config.sub_configs.keys():
        sub_config = getattr(self.config, key)
        curr_attn_implementation = (
            requested_attn_implementation
            if not isinstance(requested_attn_implementation, dict)
            else requested_attn_implementation.get(key, None)
        )
        # For models with backbone sub-config might be not initialized. Set the requested att
        # if the config hasn't got any attn pre-set and the requested attn in not `None` (i.e not the default attn)
        if (
            sub_config is not None
            and sub_config._attn_implementation_internal is None
            and curr_attn_implementation is not None
        ):
            sub_config._attn_implementation_internal = curr_attn_implementation

    # A single if/elif chain: once one implementation has been dispatched, no later branch (in
    # particular the final `else`) may overwrite it.
    if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch():
        self.config._attn_implementation = "flash_attention_3"
    elif requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch():
        self.config._attn_implementation = "flash_attention_2"
    elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch():
        self.config._attn_implementation = "flex_attention"
    elif (
        requested_attn_implementation in [None, "sdpa"]
        and not is_torch_xla_available()
        and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None)
    ):
        self.config._attn_implementation = "sdpa"
    elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
        self.config._attn_implementation = requested_attn_implementation
    elif isinstance(requested_attn_implementation, dict):
        self.config._attn_implementation = requested_attn_implementation.get("", None)
    else:
        self.config._attn_implementation = "eager"

    self.config._attn_implementation_autoset = True
@classmethod
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
"""
@@ -2769,41 +2956,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
"""
self._require_grads_hook.remove()
def get_input_embeddings(self) -> nn.Module:
    """
    Fetch the input embeddings by delegating to the base model.

    The base model is the attribute named by `self.base_model_prefix`; when no such attribute
    exists the lookup resolves to `self` and there is nothing to delegate to.

    Returns:
        `nn.Module`: A torch module mapping vocabulary to hidden states.

    Raises:
        NotImplementedError: if the base model is this very object.
    """
    delegate = getattr(self, self.base_model_prefix, self)
    if delegate is self:
        # No distinct base model: concrete classes are expected to provide their own getter.
        raise NotImplementedError
    return delegate.get_input_embeddings()
def set_input_embeddings(self, value: nn.Module):
    """
    Install new input embeddings by delegating to the base model.

    The base model is the attribute named by `self.base_model_prefix`; when no such attribute
    exists the lookup resolves to `self` and there is nothing to delegate to.

    Args:
        value (`nn.Module`): A module mapping vocabulary to hidden states.

    Raises:
        NotImplementedError: if the base model is this very object.
    """
    delegate = getattr(self, self.base_model_prefix, self)
    if delegate is self:
        # No distinct base model: concrete classes are expected to provide their own setter.
        raise NotImplementedError
    delegate.set_input_embeddings(value)
def get_output_embeddings(self) -> Union[nn.Module, None]:
    """
    Returns the model's output embeddings.

    This default implementation always returns `None`; models that have output embeddings
    overwrite it.

    Returns:
        `nn.Module` or `None`: A torch module mapping hidden states to vocabulary, or `None` when
        the model has no output embeddings.
    """
    return None  # Overwrite for models with output embeddings
def _init_weights(self, module):
"""
Initialize the weights. This method should be overridden by derived class and is