🚨🚨 Fix and simplify attention implementation dispatch and subconfigs handling (#39423)

* first try

* Update modeling_utils.py

* Update modeling_utils.py

* big refactor

* Update modeling_utils.py

* style

* docstrings and simplify inner workings of configs

* remove all trace of _internal

* Update modeling_utils.py

* fix logic error

* Update modeling_utils.py

* recursive on config

* Update configuration_utils.py

* fix

* Update configuration_dpt.py

* Update configuration_utils.py

* Update configuration_utils.py

* Update modeling_idefics.py

* Update modeling_utils.py

* fix for old models

* more old models fixup

* Update modeling_utils.py

* Update configuration_utils.py

* Remove outdated test

* remove the deepcopy!! 🥵🥵

* Update test_modeling_gpt_bigcode.py

* fix qwen dispatch

* restrict to only models supporting it

* style

* switch name

* Update modeling_utils.py

* Update modeling_utils.py

* add tests!

* fix

* rypo

* remove bad copies

* fix

* Update modeling_utils.py

* additional check

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* fix

* skip
This commit is contained in:
Cyril Vallez
2025-07-18 13:41:54 +02:00
committed by GitHub
parent 2b819ba4e3
commit 4ded9a4113
33 changed files with 472 additions and 323 deletions

View File

@@ -24,6 +24,7 @@ import json
import os
import re
import shutil
import sys
import tempfile
import warnings
from abc import abstractmethod
@@ -2094,12 +2095,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
)
self.config = config
# The `hasattr` here is used as some Transformers tests for some reason do not call
# PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
if hasattr(config, "_attn_implementation_internal") and not getattr(
config, "_attn_implementation_autoset", False
):
self.set_attention_implementation(self.config._attn_implementation_internal)
# Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
# setting it recursively)
self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
self.config._attn_implementation, is_init_check=True
)
# for initialization of the loss
loss_type = self.__class__.__name__
@@ -2244,14 +2244,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if torch_dtype is not None:
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config.
if config._attn_implementation_internal is not None:
# In this case, the config has been created with the attn_implementation set by the user, which we should respect.
attn_implementation = config._attn_implementation_internal
else:
attn_implementation = None
config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation)
# If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
if "attn_implementation" in kwargs:
config._attn_implementation = kwargs.pop("attn_implementation")
if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
@@ -2272,101 +2267,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
return model
@classmethod
def _check_attn_implementation(cls, attn_implementation: Union[dict, str]) -> Union[dict, str]:
"""
Checks that the requested attention implementation exists and tries to get the kernel from hub
if `attn_implementation` matches hf kernels pattern.
"""
if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation):
if not is_kernels_available():
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
# Extract repo_id and kernel_name from the string
repo_id, kernel_name = attn_implementation.split(":")
kernel_name = kernel_name.strip()
repo_id = repo_id.strip()
try:
kernel = get_kernel(repo_id)
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
except FileNotFoundError as e:
logger.warning(
f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead."
)
attn_implementation = None # try to dispatch SDPA and fallback eager if not available
except AttributeError:
raise ValueError(
"the kernel function name or class specified in the attn_implementation argument is not valid. \
Please check the documentation for the correct format, \
and check that the kernel exports the class and the function correctly."
)
if (
not isinstance(attn_implementation, dict)
and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys()
):
message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False):
message += (
', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
)
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
if cls._supports_flex_attn:
message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
raise ValueError(message + ".")
return attn_implementation
def set_attention_implementation(self, attn_implementation: Union[dict, str]):
"""
Checks and dispatches to the requested attention implementation.
"""
requested_attn_implementation = self._check_attn_implementation(attn_implementation)
# Composite models consisting of several PretrainedModels can specify attention implementation as a dict where
# keys are sub-config names. But most people will specify one `str` which means that should dispatch it for all sub-models.
# See https://github.com/huggingface/transformers/pull/32238
for key in self.config.sub_configs.keys():
sub_config = getattr(self.config, key)
curr_attn_implementation = (
requested_attn_implementation
if not isinstance(requested_attn_implementation, dict)
else requested_attn_implementation.get(key, None)
)
# For models with backbone sub-config might be not initialized. Set the requested att
# if the config hasn't got any attn pre-set and the requested attn in not `None` (i.e not the default attn)
if (
sub_config is not None
and sub_config._attn_implementation_internal is None
and curr_attn_implementation is not None
):
sub_config._attn_implementation_internal = curr_attn_implementation
if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch():
self.config._attn_implementation = "flash_attention_3"
if requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch():
self.config._attn_implementation = "flash_attention_2"
elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch():
self.config._attn_implementation = "flex_attention"
elif (
requested_attn_implementation in [None, "sdpa"]
and not is_torch_xla_available()
and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None)
):
self.config._attn_implementation = "sdpa"
elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
self.config._attn_implementation = requested_attn_implementation
elif isinstance(requested_attn_implementation, dict):
self.config._attn_implementation = requested_attn_implementation.get("", None)
else:
self.config._attn_implementation = "eager"
self.config._attn_implementation_autoset = True
@classmethod
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
"""
@@ -2439,15 +2339,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Otherwise, can't generate
return False
def _flash_attn_2_can_dispatch(self) -> bool:
def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool:
"""
Checks the availability of Flash Attention 2 and compatibility with the current model.
Check the availability of Flash Attention 2 for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
Args:
is_init_check (`bool`, *optional*):
Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
before instantiating the full models if we know that the model does not support the requested attention.
"""
# Config always has `torch_dtype` but we need the next line for `no_super_init()` tests
torch_dtype = self.config.torch_dtype if hasattr(self.config, "torch_dtype") else torch.get_default_dtype()
device_map = self.hf_device_map if hasattr(self, "hf_device_map") else None
torch_dtype = self.config.torch_dtype
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
if not (self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False)):
@@ -2486,68 +2389,62 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
else:
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
_is_bettertransformer = getattr(self, "use_bettertransformer", False)
if _is_bettertransformer:
raise ValueError(
"Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
)
if torch_dtype is None:
logger.warning_once(
"You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
"You are attempting to use Flash Attention 2 without specifying a torch dtype. This might lead to unexpected behaviour"
)
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
logger.warning_once(
"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but"
"Flash Attention 2 only supports torch.float16 and torch.bfloat16 dtypes, but"
f" the current dype in {self.__class__.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`'
)
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
if device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
if torch.cuda.is_available():
logger.warning_once(
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
" after initializing it on CPU with `model.to('cuda')`."
)
elif is_torch_mlu_available():
logger.warning_once(
"You are attempting to use Flash Attention 2.0 with a model not initialized on MLU. Make sure to move the model to MLU"
" after initializing it on CPU with `model.to('mlu')`."
)
else:
# With the early check, the parameters are not yet initalized correctly
if not is_init_check:
if getattr(self, "use_bettertransformer", False):
raise ValueError(
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. "
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
"or initialising the model on CPU and then moving it to GPU."
"Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
)
elif (
device_map is not None
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
raise ValueError(
"You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
)
param_devices = list({param.device for param in self.parameters()})
if len(param_devices) == 1 and param_devices[0].type == "cpu":
if torch.cuda.is_available():
logger.warning_once(
"You are attempting to use Flash Attention 2 with a model not initialized on GPU. Make sure to move the model to GPU"
" after initializing it on CPU with `model.to('cuda')`."
)
elif is_torch_mlu_available():
logger.warning_once(
"You are attempting to use Flash Attention 2 with a model not initialized on MLU. Make sure to move the model to MLU"
" after initializing it on CPU with `model.to('mlu')`."
)
else:
raise ValueError(
"You are attempting to use Flash Attention 2 with a model not initialized on GPU and with no GPU available. "
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
"or initialising the model on CPU and then moving it to GPU."
)
# If no error raise by this point, we can return `True`
return True
def _flash_attn_3_can_dispatch(self) -> bool:
def _flash_attn_3_can_dispatch(self, is_init_check: bool = False) -> bool:
"""
Checks the availability of Flash Attention 3 and compatibility with the current model.
Check the availability of Flash Attention 3 for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module.
Args:
is_init_check (`bool`, *optional*):
Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
before instantiating the full models if we know that the model does not support the requested attention.
"""
# Config always has `torch_dtype` but we need the next line for `no_super_init()` tests
torch_dtype = self.config.torch_dtype if hasattr(self.config, "torch_dtype") else torch.get_default_dtype()
device_map = self.hf_device_map if hasattr(self, "hf_device_map") else None
torch_dtype = self.config.torch_dtype
if not self._supports_flash_attn:
raise ValueError(
f"{self.__class__.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
f"{self.__class__.__name__} does not support Flash Attention 3 yet. Please request to add support where"
f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new"
" or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
)
@@ -2591,48 +2488,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
f"Model has attention_dropout={self.config.attention_dropout}, which is not supported by Flash Attention 3."
)
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
if device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
if torch.cuda.is_available():
logger.warning_once(
"You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU"
" after initializing it on CPU with `model.to('cuda')`."
)
else:
raise ValueError(
"You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. "
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
"or initialising the model on CPU and then moving it to GPU."
)
elif (
device_map is not None
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
raise ValueError(
"You are attempting to use Flash Attention 3 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
)
# With the early check, the parameters are not yet initalized correctly
if not is_init_check:
param_devices = list({param.device for param in self.parameters()})
if len(param_devices) == 1 and param_devices[0].type == "cpu":
if torch.cuda.is_available():
logger.warning_once(
"You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU"
" after initializing it on CPU with `model.to('cuda')`."
)
else:
raise ValueError(
"You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. "
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
"or initialising the model on CPU and then moving it to GPU."
)
return True
def _sdpa_can_dispatch(self, hard_check_only: bool = False) -> bool:
def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool:
"""
Checks the availability of SDPA for a given model.
Check the availability of SDPA for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module.
Args:
is_init_check (`bool`, *optional*):
Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
before instantiating the full models if we know that the model does not support the requested attention.
"""
if hard_check_only:
if not self._supports_sdpa:
raise ValueError(
f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
)
if not is_torch_sdpa_available():
raise ImportError(
"PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1."
)
if not self._supports_sdpa:
raise ValueError(
f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
)
if not is_torch_sdpa_available():
raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1.")
if (
torch.version.hip is not None
@@ -2644,18 +2536,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
)
torch.backends.cuda.enable_flash_sdp(False)
# This means we have `hard_check_only=False` and fallback to eager if SDPA isn't supported
_is_bettertransformer = getattr(self, "use_bettertransformer", False)
if not is_torch_sdpa_available() or not self._supports_sdpa or _is_bettertransformer:
return False
if not is_init_check:
if getattr(self, "use_bettertransformer", False):
raise ValueError(
"SDPA and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
)
return True
def _flex_attn_can_dispatch(self) -> bool:
def _flex_attn_can_dispatch(self, is_init_check: bool = False) -> bool:
"""
Checks the availability of Flex Attention for a given model.
Check the availability of Flex Attention for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module.
Args:
is_init_check (`bool`, *optional*):
Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
before instantiating the full models if we know that the model does not support the requested attention.
"""
if not self._supports_flex_attn:
raise ValueError(
@@ -2670,9 +2568,190 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
"PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
)
if not is_init_check:
if getattr(self, "use_bettertransformer", False):
raise ValueError(
"FlexAttention and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
)
# If no error raise by this point, we can return `True`
return True
def _check_and_adjust_attn_implementation(
self, attn_implementation: Optional[str], is_init_check: bool = False
) -> str:
"""
Check that the `attn_implementation` exists and is supported by the models, and try to get the kernel from hub if
it matches hf kernels pattern.
Args:
attn_implementation (`str` or `None`):
The attention implementation to check for existence/validity.
is_init_check (`bool`, *optional*):
Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
before instantiating the full models if we know that the model does not support the requested attention.
Returns:
`str`: The final attention implementation to use, including potential fallbacks from sdpa to eager, or from
None to sdpa (to potentially eager).
"""
applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation
if re.match(r"^[^/:]+/[^/:]+:[^/:]+$", applicable_attn_implementation):
if not is_kernels_available():
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
# Extract repo_id and kernel_name from the string
repo_id, kernel_name = applicable_attn_implementation.split(":")
kernel_name = kernel_name.strip()
repo_id = repo_id.strip()
try:
kernel = get_kernel(repo_id)
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
applicable_attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
except FileNotFoundError as e:
logger.warning_once(
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
"default attention implementation instead (sdpa if available, eager otherwise)."
)
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
except AttributeError:
raise ValueError(
"the kernel function name or class specified in the attn_implementation argument is not valid. Please check "
"the documentation for the correct format, and check that the kernel exports the class and the function correctly."
)
if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
message = (
f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '
'`attn_implementation="eager"` (manual attention implementation)'
)
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
if self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False):
message += (
', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
)
if self._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
if self._supports_flex_attn:
message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
raise ValueError(message + ".")
# Perform relevant checks
if applicable_attn_implementation == "flash_attention_2":
self._flash_attn_2_can_dispatch(is_init_check)
elif applicable_attn_implementation == "flash_attention_3":
self._flash_attn_3_can_dispatch(is_init_check)
elif applicable_attn_implementation == "flex_attention":
self._flex_attn_can_dispatch(is_init_check)
elif applicable_attn_implementation == "sdpa":
# Sdpa is the default, so we try it and fallback to eager otherwise when not possible
try:
self._sdpa_can_dispatch(is_init_check)
except (ValueError, ImportError) as e:
# In this case, sdpa was requested explicitly, but we can't use it, so let's raise
if attn_implementation == "sdpa":
raise e
applicable_attn_implementation = "eager"
return applicable_attn_implementation
@classmethod
def _can_set_attn_implementation(cls) -> bool:
"""Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
opening the file, but avoids maintaining yet another property flag.
"""
class_file = sys.modules[cls.__module__].__file__
with open(class_file, "r") as f:
code = f.read()
# heuristic -> if we find those patterns, the model uses the correct interface
return (
"eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code
)
def set_attn_implementation(self, attn_implementation: Union[str, dict]):
"""
Set the requested `attn_implementation` for this model.
Args:
attn_implementation (`str` or `dict`):
The attention implementation to set for this model. It can be either a `str`, in which case it will be
dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each
submodel will dispatch the corresponding value.
"""
requested_implementation = (
attn_implementation
if not isinstance(attn_implementation, dict)
else attn_implementation.get("", self.config._attn_implementation)
)
# At this point, the model was already instantiated, so instead of crashing on bad value, let's simply
# warn the user that the requested value is not working
if requested_implementation != self.config._attn_implementation:
# In this case, raise
if not self._can_set_attn_implementation():
logger.warning(
f"{self.__class__.__name__} does not support setting its attention implementation dynamically, because it "
"does not follow the functional approach based on AttentionInterface "
"(see https://huggingface.co/docs/transformers/en/attention_interface)"
)
else:
try:
applicable_attn_implementation = self._check_and_adjust_attn_implementation(
requested_implementation, is_init_check=False
)
# Apply the change (on the internal attr, to avoid setting it recursively)
self.config._attn_implementation_internal = applicable_attn_implementation
except (ValueError, ImportError) as e:
logger.warning(
f"Impossible to set the requested `attn_implementation`. The following error was captured: {str(e)}"
)
subconfigs_changed = set()
# Apply it to all submodels as well
for submodule in self.modules():
# We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model",
# e.g. ForCausalLM has a Model inside, but no need to check it again)
if (
submodule is not self
and isinstance(submodule, PreTrainedModel)
and submodule.config.__class__ != self.config.__class__
):
sub_implementation = attn_implementation
if isinstance(attn_implementation, dict):
for subconfig_key in self.config.sub_configs:
# We need to check for exact object match here, with `is`
if getattr(self.config, subconfig_key) is submodule.config:
sub_implementation = attn_implementation.get(
subconfig_key, submodule.config._attn_implementation
)
break
submodule.set_attn_implementation(sub_implementation)
subconfigs_changed.add(submodule.config.__class__)
# We need this as some old and badly designed models use subconfigs without declaring the corresponding modules as PreTrainedModel
for subconfig_key in self.config.sub_configs:
subconfig = getattr(self.config, subconfig_key)
requested_implementation = (
attn_implementation
if not isinstance(attn_implementation, dict)
else attn_implementation.get(subconfig_key, subconfig._attn_implementation)
)
# This means we did not perform any check above for this particular subconfig -> set it in the dark if it is registered
if (
subconfig.__class__ not in subconfigs_changed
and requested_implementation != subconfig._attn_implementation
and requested_implementation in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
):
subconfig._attn_implementation_internal = requested_implementation
logger.warning(
f"We set the attention implementation for the sub-config `{subconfig_key}` to `{requested_implementation}` "
"without finding the associated sub-model. For this reason we could not check if the model supports it. "
"You may encounter undefined behavior."
)
def enable_input_require_grads(self):
"""
Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
@@ -4601,21 +4680,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if "gguf_file" in model_kwargs:
model_kwargs.pop("gguf_file")
else:
# In case one passes a config to `from_pretrained` + "attn_implementation"
# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
# Please see: https://github.com/huggingface/transformers/issues/28038
# Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory
# we pop attn_implementation from the kwargs but this handles the case where users
# passes manually the config to `from_pretrained`.
config = copy.deepcopy(config)
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
if kwarg_attn_imp is not None:
config._attn_implementation = kwarg_attn_imp
model_kwargs = kwargs
# Because some composite configs call super().__init__ before instantiating the sub-configs, we need this call
# to correctly redispatch recursively if the kwarg is provided
if "attn_implementation" in kwargs:
config._attn_implementation = kwargs.pop("attn_implementation")
transformers_explicit_filename = getattr(config, "transformers_weights", None)
if transformers_explicit_filename is not None: