Refactor embedding input/output getter/setter (#39339)
* simplify common get/set
* remove some noise
* change some 5 years old modeling utils
* update examples
* fix copies
* revert some changes
* fixes, gah
* format
* move to Mixin
* remove smolvlm specific require grad
* skip
* force defaults
* remodularise some stuff
* remodularise more stuff
* add safety for audio models
* style
* have a correct fallback, you daft donkey
* remove this argh
* change heuristic for audio models
* fixup
* revert
* this works
* revert again
* 🧠
* aaah ESM has two modelings aaah
* add informative but short comment
* add `input_embed_layer` mixin attribute
* style
* walrus has low precedence
* modular fix
* this was breaking parser
This commit is contained in:
@@ -1902,7 +1902,97 @@ class ModuleUtilsMixin:
|
||||
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
|
||||
class EmbeddingAccessMixin:
|
||||
"""
|
||||
Base utilities to regroup getters and setters for embeddings.
|
||||
Introduces the `input_layer_embed` attribute, which indicates
|
||||
where the input embeddings come from and where they
|
||||
should be set.
|
||||
"""
|
||||
|
||||
_input_embed_layer = "embed_tokens" # default layer that holds input embeddings.
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
"""
|
||||
Returns the model's input embeddings.
|
||||
|
||||
Returns:
|
||||
`nn.Module`: A torch module mapping vocabulary to hidden states.
|
||||
"""
|
||||
|
||||
# 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
|
||||
# for most NLP models), and if so, return it.
|
||||
|
||||
name = getattr(self, "_input_embed_layer", "embed_tokens")
|
||||
|
||||
if (default_embedding := getattr(self, name, None)) is not None:
|
||||
return default_embedding
|
||||
# 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
|
||||
|
||||
if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
|
||||
return self.model.embed_tokens
|
||||
|
||||
# 3) vanilla decoder‑only architectures
|
||||
elif hasattr(self, "embed_tokens"):
|
||||
return self.embed_tokens
|
||||
else:
|
||||
base_model = getattr(self, "base_model_prefix", None)
|
||||
if base_model is not None:
|
||||
base_model = getattr(self, base_model, None)
|
||||
if base_model is not None and base_model is not self:
|
||||
return base_model.get_input_embeddings()
|
||||
raise NotImplementedError(
|
||||
f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
|
||||
"please override in the subclass."
|
||||
)
|
||||
|
||||
def set_input_embeddings(self, value: nn.Module):
|
||||
"""Fallback setter that handles **~70 %** of models in the code‑base.
|
||||
|
||||
Order of attempts:
|
||||
1. `self.model.embed_tokens`
|
||||
2. `self.embed_tokens`
|
||||
3. delegate to the *base model* if one exists
|
||||
4. otherwise raise `NotImplementedError` so subclasses still can (and
|
||||
should) override for exotic layouts.
|
||||
"""
|
||||
|
||||
# 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
|
||||
name = getattr(self, "_input_embed_layer", "embed_tokens")
|
||||
if hasattr(self, "model") and hasattr(self.model, name):
|
||||
setattr(self.model, name, value)
|
||||
# 2) as well as vanilla decoder‑only architectures
|
||||
elif hasattr(self, name):
|
||||
setattr(self, name, value)
|
||||
# 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
|
||||
elif getattr(self, self.base_model_prefix, self) is not self:
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
base_model.set_input_embeddings(value)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
|
||||
)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
if not hasattr(self, "lm_head"):
|
||||
return None
|
||||
try:
|
||||
# Speech / vision backbones raise here, so we return None.
|
||||
# Legit use of get_input_embs?
|
||||
self.get_input_embeddings()
|
||||
except NotImplementedError:
|
||||
return None
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
"""
|
||||
Sets the model's output embedding, defaulting to setting new_embeddings to lm_head.
|
||||
"""
|
||||
if getattr(self, "lm_head"):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
Base class for all models.
|
||||
|
||||
@@ -2004,6 +2094,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
_supports_attention_backend = False
|
||||
_can_record_outputs = None
|
||||
|
||||
# This attribute sets the default parameter to be
|
||||
|
||||
@property
|
||||
@torch._dynamo.allow_in_graph
|
||||
def can_record_outputs(self) -> dict[str, OutputRecorder]:
|
||||
@@ -2267,6 +2359,101 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]:
|
||||
"""
|
||||
Checks that the requested attention implementation exists and tries to get the kernel from hub
|
||||
if `attn_implementation` matches hf kernels pattern.
|
||||
"""
|
||||
if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation):
|
||||
if not is_kernels_available():
|
||||
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
|
||||
|
||||
# Extract repo_id and kernel_name from the string
|
||||
repo_id, kernel_name = attn_implementation.split(":")
|
||||
kernel_name = kernel_name.strip()
|
||||
repo_id = repo_id.strip()
|
||||
|
||||
try:
|
||||
kernel = get_kernel(repo_id)
|
||||
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
|
||||
attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(
|
||||
f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead."
|
||||
)
|
||||
attn_implementation = None # try to dispatch SDPA and fallback eager if not available
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"the kernel function name or class specified in the attn_implementation argument is not valid. \
|
||||
Please check the documentation for the correct format, \
|
||||
and check that the kernel exports the class and the function correctly."
|
||||
)
|
||||
if (
|
||||
not isinstance(attn_implementation, dict)
|
||||
and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys()
|
||||
):
|
||||
message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
|
||||
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
|
||||
if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False):
|
||||
message += (
|
||||
', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
|
||||
', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
|
||||
)
|
||||
if cls._supports_sdpa:
|
||||
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
|
||||
if cls._supports_flex_attn:
|
||||
message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
|
||||
raise ValueError(message + ".")
|
||||
|
||||
return attn_implementation
|
||||
|
||||
def set_attention_implementation(self, attn_implementation: Union[str, dict]):
|
||||
"""
|
||||
Checks and dispatches to the requested attention implementation.
|
||||
"""
|
||||
requested_attn_implementation = self._check_attn_implementation(attn_implementation)
|
||||
|
||||
# Composite models consisting of several PretrainedModels can specify attention implementation as a dict where
|
||||
# keys are sub-config names. But most people will specify one `str` which means that should dispatch it for all sub-models.
|
||||
# See https://github.com/huggingface/transformers/pull/32238
|
||||
for key in self.config.sub_configs.keys():
|
||||
sub_config = getattr(self.config, key)
|
||||
curr_attn_implementation = (
|
||||
requested_attn_implementation
|
||||
if not isinstance(requested_attn_implementation, dict)
|
||||
else requested_attn_implementation.get(key, None)
|
||||
)
|
||||
# For models with backbone sub-config might be not initialized. Set the requested att
|
||||
# if the config hasn't got any attn pre-set and the requested attn in not `None` (i.e not the default attn)
|
||||
if (
|
||||
sub_config is not None
|
||||
and sub_config._attn_implementation_internal is None
|
||||
and curr_attn_implementation is not None
|
||||
):
|
||||
sub_config._attn_implementation_internal = curr_attn_implementation
|
||||
|
||||
if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch():
|
||||
self.config._attn_implementation = "flash_attention_3"
|
||||
if requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch():
|
||||
self.config._attn_implementation = "flash_attention_2"
|
||||
elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch():
|
||||
self.config._attn_implementation = "flex_attention"
|
||||
elif (
|
||||
requested_attn_implementation in [None, "sdpa"]
|
||||
and not is_torch_xla_available()
|
||||
and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None)
|
||||
):
|
||||
self.config._attn_implementation = "sdpa"
|
||||
elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
self.config._attn_implementation = requested_attn_implementation
|
||||
elif isinstance(requested_attn_implementation, dict):
|
||||
self.config._attn_implementation = requested_attn_implementation.get("", None)
|
||||
else:
|
||||
self.config._attn_implementation = "eager"
|
||||
|
||||
self.config._attn_implementation_autoset = True
|
||||
|
||||
@classmethod
|
||||
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
|
||||
"""
|
||||
@@ -2769,41 +2956,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
"""
|
||||
self._require_grads_hook.remove()
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
"""
|
||||
Returns the model's input embeddings.
|
||||
|
||||
Returns:
|
||||
`nn.Module`: A torch module mapping vocabulary to hidden states.
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
if base_model is not self:
|
||||
return base_model.get_input_embeddings()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_input_embeddings(self, value: nn.Module):
|
||||
"""
|
||||
Set model's input embeddings.
|
||||
|
||||
Args:
|
||||
value (`nn.Module`): A module mapping vocabulary to hidden states.
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
if base_model is not self:
|
||||
base_model.set_input_embeddings(value)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_output_embeddings(self) -> nn.Module:
|
||||
"""
|
||||
Returns the model's output embeddings.
|
||||
|
||||
Returns:
|
||||
`nn.Module`: A torch module mapping hidden states to vocabulary.
|
||||
"""
|
||||
return None # Overwrite for models with output embeddings
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""
|
||||
Initialize the weights. This method should be overridden by derived class and is
|
||||
|
||||
Reference in New Issue
Block a user