🚨🚨 Fix and simplify attention implementation dispatch and subconfigs handling (#39423)
* first try * Update modeling_utils.py * Update modeling_utils.py * big refactor * Update modeling_utils.py * style * docstrings and simplify inner workings of configs * remove all trace of _internal * Update modeling_utils.py * fix logic error * Update modeling_utils.py * recursive on config * Update configuration_utils.py * fix * Update configuration_dpt.py * Update configuration_utils.py * Update configuration_utils.py * Update modeling_idefics.py * Update modeling_utils.py * fix for old models * more old models fixup * Update modeling_utils.py * Update configuration_utils.py * Remove outdated test * remove the deepcopy!! 🥵🥵 * Update test_modeling_gpt_bigcode.py * fix qwen dispatch * restrict to only models supporting it * style * switch name * Update modeling_utils.py * Update modeling_utils.py * add tests! * fix * rypo * remove bad copies * fix * Update modeling_utils.py * additional check * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * fix * skip
This commit is contained in:
@@ -24,6 +24,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
@@ -2094,12 +2095,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
)
|
||||
self.config = config
|
||||
|
||||
# The `hasattr` here is used as some Transformers tests for some reason do not call
|
||||
# PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
|
||||
if hasattr(config, "_attn_implementation_internal") and not getattr(
|
||||
config, "_attn_implementation_autoset", False
|
||||
):
|
||||
self.set_attention_implementation(self.config._attn_implementation_internal)
|
||||
# Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
|
||||
# setting it recursively)
|
||||
self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
|
||||
self.config._attn_implementation, is_init_check=True
|
||||
)
|
||||
|
||||
# for initialization of the loss
|
||||
loss_type = self.__class__.__name__
|
||||
@@ -2244,14 +2244,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
if torch_dtype is not None:
|
||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||
|
||||
config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config.
|
||||
|
||||
if config._attn_implementation_internal is not None:
|
||||
# In this case, the config has been created with the attn_implementation set by the user, which we should respect.
|
||||
attn_implementation = config._attn_implementation_internal
|
||||
else:
|
||||
attn_implementation = None
|
||||
config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation)
|
||||
# If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
|
||||
if "attn_implementation" in kwargs:
|
||||
config._attn_implementation = kwargs.pop("attn_implementation")
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
|
||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||
@@ -2272,101 +2267,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
def _check_attn_implementation(cls, attn_implementation: Union[dict, str]) -> Union[dict, str]:
    """
    Check that the requested attention implementation exists, fetching it from the hub when it is a kernel
    reference.

    Args:
        attn_implementation (`dict`, `str` or `None`):
            The requested implementation. A string of the form `<org>/<repo>:<function>` is treated as a hub
            kernel reference. A `dict` is returned untouched, no validation is performed on it here.

    Returns:
        The value to dispatch on: the key `kernel_<org>_<repo>` under which a hub kernel was registered, `None`
        when the kernel repository could not be found, or the input value otherwise.

    Raises:
        ValueError: If a kernel is requested but `kernels` is not installed, if the kernel does not export the
            requested attribute, or if the name is neither `"eager"`, `None`, nor a registered implementation.
    """
    if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation):
        if not is_kernels_available():
            raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")

        # Extract repo_id and kernel_name from the string
        repo_id, kernel_name = attn_implementation.split(":")
        kernel_name = kernel_name.strip()
        repo_id = repo_id.strip()
        # Key under which the kernel function is registered (only the repo is part of the key)
        registered_name = f"kernel_{repo_id.replace('/', '_')}"

        try:
            kernel = get_kernel(repo_id)
            ALL_ATTENTION_FUNCTIONS.register(registered_name, getattr(kernel, kernel_name))
            attn_implementation = registered_name
        except FileNotFoundError as e:
            logger.warning(
                f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
                "default attention implementation instead (sdpa if available, eager otherwise)."
            )
            attn_implementation = None  # try to dispatch SDPA and fallback eager if not available
        except AttributeError as e:
            # Implicit string concatenation (instead of backslash continuations) keeps source indentation
            # out of the message.
            raise ValueError(
                "the kernel function name or class specified in the attn_implementation argument is not valid. "
                "Please check the documentation for the correct format, and check that the kernel exports the "
                "class and the function correctly."
            ) from e

    if (
        not isinstance(attn_implementation, dict)
        and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys()
    ):
        message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
        # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
        if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False):
            message += (
                ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
                ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
            )
        if cls._supports_sdpa:
            message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
        if cls._supports_flex_attn:
            message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
        raise ValueError(message + ".")

    return attn_implementation
|
||||
|
||||
def set_attention_implementation(self, attn_implementation: Union[dict, str]):
    """
    Checks and dispatches to the requested attention implementation.

    The request is first validated with `_check_attn_implementation`. It is then copied onto every sub-config
    that has no `_attn_implementation_internal` set yet, and finally `self.config._attn_implementation` is set
    according to what can be dispatched. `self.config._attn_implementation_autoset` is set to `True` at the end.

    Args:
        attn_implementation (`dict`, `str` or `None`):
            Either one implementation name applied to the model and all its sub-configs, or a `dict` whose keys
            are sub-config names (the `""` key addresses this model's own config).
    """
    requested_attn_implementation = self._check_attn_implementation(attn_implementation)
    per_sub_config = isinstance(requested_attn_implementation, dict)

    # Composite models consisting of several PretrainedModels can specify attention implementation as a dict where
    # keys are sub-config names. But most people will specify one `str` which means that should dispatch it for all sub-models.
    # See https://github.com/huggingface/transformers/pull/32238
    for key in self.config.sub_configs.keys():
        sub_config = getattr(self.config, key)
        curr_attn_implementation = (
            requested_attn_implementation.get(key, None) if per_sub_config else requested_attn_implementation
        )
        # For models with backbone, the sub-config might not be initialized. Set the requested attention only
        # if the config has no attention pre-set and the requested one is not `None` (i.e. not the default).
        if (
            sub_config is not None
            and sub_config._attn_implementation_internal is None
            and curr_attn_implementation is not None
        ):
            sub_config._attn_implementation_internal = curr_attn_implementation

    # One single chain: exactly one branch decides the final value.
    if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch():
        self.config._attn_implementation = "flash_attention_3"
    elif requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch():
        self.config._attn_implementation = "flash_attention_2"
    elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch():
        self.config._attn_implementation = "flex_attention"
    elif (
        requested_attn_implementation in [None, "sdpa"]
        and not is_torch_xla_available()
        # Hard check only when sdpa was requested explicitly (not when it is merely the default)
        and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None)
    ):
        self.config._attn_implementation = "sdpa"
    elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
        self.config._attn_implementation = requested_attn_implementation
    elif per_sub_config:
        self.config._attn_implementation = requested_attn_implementation.get("", None)
    else:
        self.config._attn_implementation = "eager"

    self.config._attn_implementation_autoset = True
|
||||
|
||||
@classmethod
|
||||
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
|
||||
"""
|
||||
@@ -2439,15 +2339,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# Otherwise, can't generate
|
||||
return False
|
||||
|
||||
def _flash_attn_2_can_dispatch(self) -> bool:
|
||||
def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool:
|
||||
"""
|
||||
Checks the availability of Flash Attention 2 and compatibility with the current model.
|
||||
Check the availability of Flash Attention 2 for a given model.
|
||||
|
||||
If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
|
||||
Args:
|
||||
is_init_check (`bool`, *optional*):
|
||||
Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
|
||||
fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
|
||||
BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
|
||||
before instantiating the full models if we know that the model does not support the requested attention.
|
||||
"""
|
||||
# Config always has `torch_dtype` but we need the next line for `no_super_init()` tests
|
||||
torch_dtype = self.config.torch_dtype if hasattr(self.config, "torch_dtype") else torch.get_default_dtype()
|
||||
device_map = self.hf_device_map if hasattr(self, "hf_device_map") else None
|
||||
torch_dtype = self.config.torch_dtype
|
||||
|
||||
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
|
||||
if not (self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False)):
|
||||
@@ -2486,68 +2389,62 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
else:
|
||||
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
|
||||
|
||||
_is_bettertransformer = getattr(self, "use_bettertransformer", False)
|
||||
if _is_bettertransformer:
|
||||
raise ValueError(
|
||||
"Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
|
||||
)
|
||||
|
||||
if torch_dtype is None:
|
||||
logger.warning_once(
|
||||
"You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
|
||||
"You are attempting to use Flash Attention 2 without specifying a torch dtype. This might lead to unexpected behaviour"
|
||||
)
|
||||
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
|
||||
logger.warning_once(
|
||||
"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but"
|
||||
"Flash Attention 2 only supports torch.float16 and torch.bfloat16 dtypes, but"
|
||||
f" the current dype in {self.__class__.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
|
||||
' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`'
|
||||
)
|
||||
|
||||
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
|
||||
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
|
||||
if device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
|
||||
if torch.cuda.is_available():
|
||||
logger.warning_once(
|
||||
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
|
||||
" after initializing it on CPU with `model.to('cuda')`."
|
||||
)
|
||||
elif is_torch_mlu_available():
|
||||
logger.warning_once(
|
||||
"You are attempting to use Flash Attention 2.0 with a model not initialized on MLU. Make sure to move the model to MLU"
|
||||
" after initializing it on CPU with `model.to('mlu')`."
|
||||
)
|
||||
else:
|
||||
# With the early check, the parameters are not yet initialized correctly
|
||||
if not is_init_check:
|
||||
if getattr(self, "use_bettertransformer", False):
|
||||
raise ValueError(
|
||||
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. "
|
||||
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
|
||||
"or initialising the model on CPU and then moving it to GPU."
|
||||
"Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
|
||||
)
|
||||
elif (
|
||||
device_map is not None
|
||||
and isinstance(device_map, dict)
|
||||
and ("cpu" in device_map.values() or "disk" in device_map.values())
|
||||
):
|
||||
raise ValueError(
|
||||
"You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
|
||||
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
|
||||
)
|
||||
|
||||
param_devices = list({param.device for param in self.parameters()})
|
||||
if len(param_devices) == 1 and param_devices[0].type == "cpu":
|
||||
if torch.cuda.is_available():
|
||||
logger.warning_once(
|
||||
"You are attempting to use Flash Attention 2 with a model not initialized on GPU. Make sure to move the model to GPU"
|
||||
" after initializing it on CPU with `model.to('cuda')`."
|
||||
)
|
||||
elif is_torch_mlu_available():
|
||||
logger.warning_once(
|
||||
"You are attempting to use Flash Attention 2 with a model not initialized on MLU. Make sure to move the model to MLU"
|
||||
" after initializing it on CPU with `model.to('mlu')`."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are attempting to use Flash Attention 2 with a model not initialized on GPU and with no GPU available. "
|
||||
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
|
||||
"or initialising the model on CPU and then moving it to GPU."
|
||||
)
|
||||
|
||||
# If no error raise by this point, we can return `True`
|
||||
return True
|
||||
|
||||
def _flash_attn_3_can_dispatch(self) -> bool:
|
||||
def _flash_attn_3_can_dispatch(self, is_init_check: bool = False) -> bool:
|
||||
"""
|
||||
Checks the availability of Flash Attention 3 and compatibility with the current model.
|
||||
Check the availability of Flash Attention 3 for a given model.
|
||||
|
||||
If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module.
|
||||
Args:
|
||||
is_init_check (`bool`, *optional*):
|
||||
Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
|
||||
fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
|
||||
BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
|
||||
before instantiating the full models if we know that the model does not support the requested attention.
|
||||
"""
|
||||
# Config always has `torch_dtype` but we need the next line for `no_super_init()` tests
|
||||
torch_dtype = self.config.torch_dtype if hasattr(self.config, "torch_dtype") else torch.get_default_dtype()
|
||||
device_map = self.hf_device_map if hasattr(self, "hf_device_map") else None
|
||||
torch_dtype = self.config.torch_dtype
|
||||
|
||||
if not self._supports_flash_attn:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
|
||||
f"{self.__class__.__name__} does not support Flash Attention 3 yet. Please request to add support where"
|
||||
f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new"
|
||||
" or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
|
||||
)
|
||||
@@ -2591,48 +2488,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
f"Model has attention_dropout={self.config.attention_dropout}, which is not supported by Flash Attention 3."
|
||||
)
|
||||
|
||||
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
|
||||
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
|
||||
if device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
|
||||
if torch.cuda.is_available():
|
||||
logger.warning_once(
|
||||
"You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU"
|
||||
" after initializing it on CPU with `model.to('cuda')`."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. "
|
||||
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
|
||||
"or initialising the model on CPU and then moving it to GPU."
|
||||
)
|
||||
elif (
|
||||
device_map is not None
|
||||
and isinstance(device_map, dict)
|
||||
and ("cpu" in device_map.values() or "disk" in device_map.values())
|
||||
):
|
||||
raise ValueError(
|
||||
"You are attempting to use Flash Attention 3 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
|
||||
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
|
||||
)
|
||||
# With the early check, the parameters are not yet initialized correctly
|
||||
if not is_init_check:
|
||||
param_devices = list({param.device for param in self.parameters()})
|
||||
if len(param_devices) == 1 and param_devices[0].type == "cpu":
|
||||
if torch.cuda.is_available():
|
||||
logger.warning_once(
|
||||
"You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU"
|
||||
" after initializing it on CPU with `model.to('cuda')`."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. "
|
||||
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
|
||||
"or initialising the model on CPU and then moving it to GPU."
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def _sdpa_can_dispatch(self, hard_check_only: bool = False) -> bool:
|
||||
def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool:
|
||||
"""
|
||||
Checks the availability of SDPA for a given model.
|
||||
Check the availability of SDPA for a given model.
|
||||
|
||||
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module.
|
||||
Args:
|
||||
is_init_check (`bool`, *optional*):
|
||||
Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
|
||||
fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
|
||||
BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
|
||||
before instantiating the full models if we know that the model does not support the requested attention.
|
||||
"""
|
||||
if hard_check_only:
|
||||
if not self._supports_sdpa:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
|
||||
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
|
||||
' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
|
||||
)
|
||||
if not is_torch_sdpa_available():
|
||||
raise ImportError(
|
||||
"PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1."
|
||||
)
|
||||
if not self._supports_sdpa:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
|
||||
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
|
||||
' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
|
||||
)
|
||||
if not is_torch_sdpa_available():
|
||||
raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1.")
|
||||
|
||||
if (
|
||||
torch.version.hip is not None
|
||||
@@ -2644,18 +2536,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
)
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
|
||||
# This means we have `hard_check_only=False` and fallback to eager if SDPA isn't supported
|
||||
_is_bettertransformer = getattr(self, "use_bettertransformer", False)
|
||||
if not is_torch_sdpa_available() or not self._supports_sdpa or _is_bettertransformer:
|
||||
return False
|
||||
if not is_init_check:
|
||||
if getattr(self, "use_bettertransformer", False):
|
||||
raise ValueError(
|
||||
"SDPA and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def _flex_attn_can_dispatch(self) -> bool:
|
||||
def _flex_attn_can_dispatch(self, is_init_check: bool = False) -> bool:
|
||||
"""
|
||||
Checks the availability of Flex Attention for a given model.
|
||||
Check the availability of Flex Attention for a given model.
|
||||
|
||||
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module.
|
||||
Args:
|
||||
is_init_check (`bool`, *optional*):
|
||||
Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
|
||||
fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
|
||||
BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
|
||||
before instantiating the full models if we know that the model does not support the requested attention.
|
||||
"""
|
||||
if not self._supports_flex_attn:
|
||||
raise ValueError(
|
||||
@@ -2670,9 +2568,190 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
"PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
|
||||
)
|
||||
|
||||
if not is_init_check:
|
||||
if getattr(self, "use_bettertransformer", False):
|
||||
raise ValueError(
|
||||
"FlexAttention and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
|
||||
)
|
||||
|
||||
# If no error raise by this point, we can return `True`
|
||||
return True
|
||||
|
||||
def _check_and_adjust_attn_implementation(
    self, attn_implementation: Optional[str], is_init_check: bool = False
) -> str:
    """
    Check that the `attn_implementation` exists and is supported by the models, and try to get the kernel from hub if
    it matches hf kernels pattern.

    Args:
        attn_implementation (`str` or `None`):
            The attention implementation to check for existence/validity. `None` is treated as a request for
            the default, `"sdpa"`.
        is_init_check (`bool`, *optional*):
            Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
            fully instantiated. It is forwarded unchanged to the per-implementation `_*_can_dispatch` checks.

    Returns:
        `str`: The final attention implementation to use, including potential fallbacks from sdpa to eager, or from
        None to sdpa (to potentially eager).

    Raises:
        ValueError: If a kernel is requested but `kernels` is not installed, if the kernel does not export the
            requested attribute, or if the name is neither `"eager"` nor a registered implementation.
        ValueError, ImportError: Whatever the selected `_*_can_dispatch` check raises. For sdpa this is only
            propagated when `"sdpa"` was requested explicitly.
    """
    applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation

    if re.match(r"^[^/:]+/[^/:]+:[^/:]+$", applicable_attn_implementation):
        if not is_kernels_available():
            raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")

        # Extract repo_id and kernel_name from the string
        repo_id, kernel_name = applicable_attn_implementation.split(":")
        kernel_name = kernel_name.strip()
        repo_id = repo_id.strip()
        # Key under which the kernel function is registered (only the repo is part of the key)
        registered_name = f"kernel_{repo_id.replace('/', '_')}"

        try:
            kernel = get_kernel(repo_id)
            ALL_ATTENTION_FUNCTIONS.register(registered_name, getattr(kernel, kernel_name))
            applicable_attn_implementation = registered_name
        except FileNotFoundError as e:
            logger.warning_once(
                f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
                "default attention implementation instead (sdpa if available, eager otherwise)."
            )
            applicable_attn_implementation = "sdpa"  # Try to fallback to sdpa in this case
        except AttributeError as e:
            # Chain the original error so the missing attribute stays visible in the traceback
            raise ValueError(
                "the kernel function name or class specified in the attn_implementation argument is not valid. Please check "
                "the documentation for the correct format, and check that the kernel exports the class and the function correctly."
            ) from e

    # "eager" is always valid; only consult the registry for anything else
    if (
        applicable_attn_implementation != "eager"
        and applicable_attn_implementation not in ALL_ATTENTION_FUNCTIONS.valid_keys()
    ):
        message = (
            f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '
            '`attn_implementation="eager"` (manual attention implementation)'
        )
        # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
        if self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False):
            message += (
                ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
                ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
            )
        if self._supports_sdpa:
            message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
        if self._supports_flex_attn:
            message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
        raise ValueError(message + ".")

    # Perform relevant checks
    if applicable_attn_implementation == "flash_attention_2":
        self._flash_attn_2_can_dispatch(is_init_check)
    elif applicable_attn_implementation == "flash_attention_3":
        self._flash_attn_3_can_dispatch(is_init_check)
    elif applicable_attn_implementation == "flex_attention":
        self._flex_attn_can_dispatch(is_init_check)
    elif applicable_attn_implementation == "sdpa":
        # Sdpa is the default, so we try it and fallback to eager otherwise when not possible
        try:
            self._sdpa_can_dispatch(is_init_check)
        except (ValueError, ImportError):
            # In this case, sdpa was requested explicitly, but we can't use it, so let's raise
            if attn_implementation == "sdpa":
                raise
            applicable_attn_implementation = "eager"

    return applicable_attn_implementation
|
||||
|
||||
@classmethod
def _can_set_attn_implementation(cls) -> bool:
    """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
    opening the file, but avoids maintaining yet another property flag.

    Returns:
        `bool`: `True` if the source file of the module defining `cls` contains both marker strings searched for
        below, `False` otherwise.
    """
    class_file = sys.modules[cls.__module__].__file__
    # Python sources are UTF-8 by default; do not depend on the locale's preferred encoding to read them
    with open(class_file, "r", encoding="utf-8") as f:
        code = f.read()
    # heuristic -> if we find those patterns, the model uses the correct interface
    return (
        "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code
    )
|
||||
|
||||
def set_attn_implementation(self, attn_implementation: Union[str, dict]):
    """
    Set the requested `attn_implementation` for this model.

    Failures to apply the requested value on this model are logged as warnings rather than raised.

    Args:
        attn_implementation (`str` or `dict`):
            The attention implementation to set for this model. It can be either a `str`, in which case it will be
            dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each
            submodel will dispatch the corresponding value. In the `dict` form, the `""` key addresses this
            model's own config, and missing keys leave the corresponding config unchanged.
    """
    per_sub_config = isinstance(attn_implementation, dict)
    requested_implementation = (
        attn_implementation.get("", self.config._attn_implementation) if per_sub_config else attn_implementation
    )

    # At this point, the model was already instantiated, so instead of crashing on bad value, let's simply
    # warn the user that the requested value is not working
    if requested_implementation != self.config._attn_implementation:
        if not self._can_set_attn_implementation():
            logger.warning(
                f"{self.__class__.__name__} does not support setting its attention implementation dynamically, because it "
                "does not follow the functional approach based on AttentionInterface "
                "(see https://huggingface.co/docs/transformers/en/attention_interface)"
            )
        else:
            try:
                applicable_attn_implementation = self._check_and_adjust_attn_implementation(
                    requested_implementation, is_init_check=False
                )
                # Apply the change (on the internal attr, to avoid setting it recursively)
                self.config._attn_implementation_internal = applicable_attn_implementation
            except (ValueError, ImportError) as e:
                logger.warning(
                    f"Impossible to set the requested `attn_implementation`. The following error was captured: {e}"
                )

    subconfigs_changed = set()
    # Apply it to all submodels as well
    for submodule in self.modules():
        # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model",
        # e.g. ForCausalLM has a Model inside, but no need to check it again)
        if (
            submodule is not self
            and isinstance(submodule, PreTrainedModel)
            and submodule.config.__class__ != self.config.__class__
        ):
            # If no sub-config key matches this submodel, the full `dict` is forwarded as-is
            sub_implementation = attn_implementation
            if per_sub_config:
                for subconfig_key in self.config.sub_configs:
                    # We need to check for exact object match here, with `is`
                    if getattr(self.config, subconfig_key) is submodule.config:
                        sub_implementation = attn_implementation.get(
                            subconfig_key, submodule.config._attn_implementation
                        )
                        break
            submodule.set_attn_implementation(sub_implementation)
            subconfigs_changed.add(submodule.config.__class__)

    # We need this as some old and badly designed models use subconfigs without declaring the corresponding modules as PreTrainedModel
    for subconfig_key in self.config.sub_configs:
        subconfig = getattr(self.config, subconfig_key)
        # A declared sub-config may be left uninitialized (`None`): there is nothing to set on it
        if subconfig is None:
            continue
        requested_implementation = (
            attn_implementation.get(subconfig_key, subconfig._attn_implementation)
            if per_sub_config
            else attn_implementation
        )
        # This means we did not perform any check above for this particular subconfig -> set it in the dark if it is registered
        if (
            subconfig.__class__ not in subconfigs_changed
            and requested_implementation != subconfig._attn_implementation
            and requested_implementation in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
        ):
            subconfig._attn_implementation_internal = requested_implementation
            logger.warning(
                f"We set the attention implementation for the sub-config `{subconfig_key}` to `{requested_implementation}` "
                "without finding the associated sub-model. For this reason we could not check if the model supports it. "
                "You may encounter undefined behavior."
            )
|
||||
|
||||
def enable_input_require_grads(self):
|
||||
"""
|
||||
Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
|
||||
@@ -4601,21 +4680,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
if "gguf_file" in model_kwargs:
|
||||
model_kwargs.pop("gguf_file")
|
||||
else:
|
||||
# In case one passes a config to `from_pretrained` + "attn_implementation"
|
||||
# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
|
||||
# Please see: https://github.com/huggingface/transformers/issues/28038
|
||||
|
||||
# Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory
|
||||
# we pop attn_implementation from the kwargs but this handles the case where users
|
||||
# passes manually the config to `from_pretrained`.
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
|
||||
if kwarg_attn_imp is not None:
|
||||
config._attn_implementation = kwarg_attn_imp
|
||||
|
||||
model_kwargs = kwargs
|
||||
|
||||
# Because some composite configs call super().__init__ before instantiating the sub-configs, we need this call
|
||||
# to correctly redispatch recursively if the kwarg is provided
|
||||
if "attn_implementation" in kwargs:
|
||||
config._attn_implementation = kwargs.pop("attn_implementation")
|
||||
|
||||
transformers_explicit_filename = getattr(config, "transformers_weights", None)
|
||||
|
||||
if transformers_explicit_filename is not None:
|
||||
|
||||
Reference in New Issue
Block a user