[refactor] set attention implementation (#38974)

* update

* fix some tests

* init from config, changes it in-place, add deepcopy in tests

* fix modernbert

* don't delete thsi config attr

* update

* style and copies

* skip tests in generation

* fix style

* accidentally removed flash-attn-3, revert

* docs

* forgot about flags set to False

* fix copies

* address a few comments

* fix copies

* custom code BC
This commit is contained in:
Raushan Turganbay
2025-07-15 12:34:06 +05:00
committed by GitHub
parent 6017f5e8ed
commit 8d6259b0b8
185 changed files with 451 additions and 776 deletions

View File

@@ -1961,11 +1961,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
supports_gradient_checkpointing = False
_is_stateful = False
# Flash Attention 2 support
_supports_flash_attn_2 = False
# Flash Attention 3 support
_supports_flash_attn_3 = False
# Flash Attention support
_supports_flash_attn = False
# SDPA support
_supports_sdpa = False
@@ -2074,12 +2071,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
"`PretrainedConfig`. To create a model from a pretrained model use "
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
if not getattr(config, "_attn_implementation_autoset", False):
# config usually has a `torch_dtype` but we need the next line for the `no_super_init` tests
dtype = config.torch_dtype if hasattr(config, "torch_dtype") else torch.get_default_dtype()
config = self._autoset_attn_implementation(config, torch_dtype=dtype, check_device_map=False)
self.config = config
# The `hasattr` here is used as some Transformers tests for some reason do not call
# PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
if hasattr(config, "_attn_implementation_internal") and not getattr(
config, "_attn_implementation_autoset", False
):
self.set_attention_implementation(self.config._attn_implementation_internal)
# for initialization of the loss
loss_type = self.__class__.__name__
if loss_type not in LOSS_MAPPING:
@@ -2226,19 +2226,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config.
if config._attn_implementation_internal is not None:
# In this case, the config has been created with the attn_implementation set by the user, which we
# should respect.
# In this case, the config has been created with the attn_implementation set by the user, which we should respect.
attn_implementation = config._attn_implementation_internal
else:
attn_implementation = None
config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation)
if not getattr(config, "_attn_implementation_autoset", False):
config = cls._autoset_attn_implementation(
config,
check_device_map=False,
torch_dtype=torch_dtype,
)
if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
@@ -2260,81 +2252,65 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
return model
@classmethod
def _autoset_attn_implementation(
cls,
config,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, dict[str, int]]] = None,
check_device_map: bool = True,
):
def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]:
"""
Automatically checks and dispatches to a default attention implementation. In order of priority:
1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
2. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
3. The default model's implementation otherwise (`LlamaAttention` for example) .
Checks that the requested attention implementation exists and tries to get the kernel from hub
if `attn_implementation` matches hf kernels pattern.
"""
# Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user.
# The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
# The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
requested_attn_implementation = None
if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
if isinstance(config._attn_implementation, str) and re.match(
r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation
):
if not is_kernels_available():
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation):
if not is_kernels_available():
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
# Extract repo_id and kernel_name from the string
repo_id, kernel_name = config._attn_implementation.split(":")
kernel_name = kernel_name.strip()
repo_id = repo_id.strip()
# Extract repo_id and kernel_name from the string
repo_id, kernel_name = attn_implementation.split(":")
kernel_name = kernel_name.strip()
repo_id = repo_id.strip()
try:
kernel = get_kernel(repo_id)
ALL_ATTENTION_FUNCTIONS.register(
f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name)
)
config._attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
except FileNotFoundError as e:
logger.warning(
f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead."
)
config._attn_implementation = "eager"
except AttributeError:
raise ValueError(
"the kernel function name or class specified in the attn_implementation argument is not valid. \
Please check the documentation for the correct format, \
and check that the kernel exports the class and the function correctly."
)
try:
kernel = get_kernel(repo_id)
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
except FileNotFoundError as e:
logger.warning(
f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead."
)
attn_implementation = None # try to dispatch SDPA and fallback eager if not available
except AttributeError:
raise ValueError(
"the kernel function name or class specified in the attn_implementation argument is not valid. \
Please check the documentation for the correct format, \
and check that the kernel exports the class and the function correctly."
)
if (
not isinstance(attn_implementation, dict)
and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys()
):
message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False):
message += (
', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
)
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
if cls._supports_flex_attn:
message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
raise ValueError(message + ".")
if (
not isinstance(config._attn_implementation, dict)
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
):
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_3:
message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
if cls._supports_flex_attn:
message += (
', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
)
raise ValueError(message + ".")
return attn_implementation
# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
requested_attn_implementation = config._attn_implementation_internal
def set_attention_implementation(self, attn_implementation: Union[str, dict]):
"""
Checks and dispatches to the requested attention implementation.
"""
requested_attn_implementation = self._check_attn_implementation(attn_implementation)
# Composite models consisting of several PretrainedModels have to specify attention impl as a dict
# where keys are sub-config names. But most people will specify one `str` which means that should dispatch it
# for all sub-models.
# Below we check if a config is composite and manually prepare a dict of attn impl if not already passed as a dict.
# Later each sub-module will dispatch with its own attn impl, by calling `XXXModel._from_config(config.text_config)`
# If any of sub-modules doesn't support requested attn, an error will be raised. See https://github.com/huggingface/transformers/pull/32238
for key in config.sub_configs.keys():
sub_config = getattr(config, key)
# Composite models consisting of several PretrainedModels can specify attention implementation as a dict where
# keys are sub-config names. But most people will specify one `str` which means that should dispatch it for all sub-models.
# See https://github.com/huggingface/transformers/pull/32238
for key in self.config.sub_configs.keys():
sub_config = getattr(self.config, key)
curr_attn_implementation = (
requested_attn_implementation
if not isinstance(requested_attn_implementation, dict)
@@ -2349,50 +2325,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
):
sub_config._attn_implementation_internal = curr_attn_implementation
if config._attn_implementation == "flash_attention_3":
cls._check_and_enable_flash_attn_3(
config,
torch_dtype=torch_dtype,
device_map=device_map,
hard_check_only=False,
check_device_map=check_device_map,
)
elif config._attn_implementation == "flash_attention_2":
cls._check_and_enable_flash_attn_2(
config,
torch_dtype=torch_dtype,
device_map=device_map,
hard_check_only=False,
check_device_map=check_device_map,
)
elif requested_attn_implementation == "flex_attention":
config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
# flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
config = cls._check_and_enable_sdpa(
config,
hard_check_only=requested_attn_implementation is not None,
)
if (
torch.version.hip is not None
and config._attn_implementation == "sdpa"
and torch.cuda.device_count() > 1
and version.parse(torch.__version__) < version.parse("2.4.1")
):
logger.warning_once(
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
)
torch.backends.cuda.enable_flash_sdp(False)
if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch():
self.config._attn_implementation = "flash_attention_3"
if requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch():
self.config._attn_implementation = "flash_attention_2"
elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch():
self.config._attn_implementation = "flex_attention"
elif (
requested_attn_implementation in [None, "sdpa"]
and not is_torch_xla_available()
and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None)
):
self.config._attn_implementation = "sdpa"
elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
config._attn_implementation = requested_attn_implementation
self.config._attn_implementation = requested_attn_implementation
elif isinstance(requested_attn_implementation, dict):
config._attn_implementation = None
self.config._attn_implementation = requested_attn_implementation.get("", None)
else:
config._attn_implementation = "eager"
self.config._attn_implementation = "eager"
config._attn_implementation_autoset = True
return config
self.config._attn_implementation_autoset = True
@classmethod
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
@@ -2466,24 +2418,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Otherwise, can't generate
return False
@classmethod
def _check_and_enable_flash_attn_2(
cls,
config,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, dict[str, int]]] = None,
check_device_map: bool = True,
hard_check_only: bool = False,
) -> PretrainedConfig:
def _flash_attn_2_can_dispatch(self) -> bool:
"""
Checks the availability of Flash Attention 2 and compatibility with the current model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
"""
if not cls._supports_flash_attn_2:
# Config always has `torch_dtype` but we need the next line for `no_super_init()` tests
torch_dtype = self.config.torch_dtype if hasattr(self.config, "torch_dtype") else torch.get_default_dtype()
device_map = self.hf_device_map if hasattr(self, "hf_device_map") else None
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
if not (self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False)):
raise ValueError(
f"{cls.__name__} does not support Flash Attention 2.0 yet. Please request to add support where"
f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
f"{self.__class__.__name__} does not support Flash Attention 2.0 yet. Please request to add support where"
f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new"
" or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
)
@@ -2491,39 +2440,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
if importlib.util.find_spec("flash_attn") is None:
# package `flash-attn` can not be installed on Ascend NPU, ignore related validation logic and early exit.
if is_torch_npu_available():
if not hard_check_only:
config._attn_implementation = "flash_attention_2"
logger.info("Detect using FlashAttention2 on Ascend NPU.")
return config
else:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
if torch.version.cuda:
if flash_attention_version < version.parse("2.1.0"):
raise ImportError(
f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
)
elif not torch.cuda.is_available():
raise ValueError(
f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device."
)
else:
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
elif torch.version.hip:
if flash_attention_version < version.parse("2.0.4"):
raise ImportError(
f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}"
)
else:
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
# package `flash-attn` can not be installed on Ascend NPU, ignore related validation logi
if importlib.util.find_spec("flash_attn") is None and not is_torch_npu_available():
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
else:
# Check FA2 installed version compatibility
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
if torch.version.cuda:
if flash_attention_version < version.parse("2.1.0"):
raise ImportError(
f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
)
elif not torch.cuda.is_available():
raise ValueError(
f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device."
)
else:
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
elif torch.version.hip:
if flash_attention_version < version.parse("2.0.4"):
raise ImportError(
f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Detected version {flash_attention_version}. {install_message}"
)
else:
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
_is_bettertransformer = getattr(self, "use_bettertransformer", False)
if _is_bettertransformer:
raise ValueError(
"Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
@@ -2536,13 +2478,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
logger.warning_once(
"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but"
f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
f" the current dype in {self.__class__.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`'
)
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
if device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
if torch.cuda.is_available():
logger.warning_once(
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
@@ -2560,8 +2502,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
"or initialising the model on CPU and then moving it to GPU."
)
elif (
check_device_map
and device_map is not None
device_map is not None
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
@@ -2569,28 +2510,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
"You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
)
if not hard_check_only:
config._attn_implementation = "flash_attention_2"
return config
@classmethod
def _check_and_enable_flash_attn_3(
cls,
config,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, dict[str, int]]] = None,
check_device_map: bool = True,
hard_check_only: bool = False,
) -> PretrainedConfig:
# If no error raise by this point, we can return `True`
return True
def _flash_attn_3_can_dispatch(self) -> bool:
"""
Checks the availability of Flash Attention 3 and compatibility with the current model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module.
"""
if not cls._supports_flash_attn_3:
# Config always has `torch_dtype` but we need the next line for `no_super_init()` tests
torch_dtype = self.config.torch_dtype if hasattr(self.config, "torch_dtype") else torch.get_default_dtype()
device_map = self.hf_device_map if hasattr(self, "hf_device_map") else None
if not self._supports_flash_attn:
raise ValueError(
f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
f"{self.__class__.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new"
" or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
)
@@ -2620,22 +2557,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
logger.warning_once(
"Flash Attention 3 only supports torch.float16 and torch.bfloat16 dtypes, but"
f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
f" the current dype in {self.__class__.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`'
)
if getattr(config, "alibi", False) or getattr(config, "use_alibi", False):
if getattr(self.config, "alibi", False) or getattr(self.config, "use_alibi", False):
raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.")
# Check for attention dropout, which is incompatible with FA3
if hasattr(config, "attention_dropout") and config.attention_dropout > 0:
if hasattr(self.config, "attention_dropout") and self.config.attention_dropout > 0:
raise ValueError(
f"Model has attention_dropout={config.attention_dropout}, which is not supported by Flash Attention 3."
f"Model has attention_dropout={self.config.attention_dropout}, which is not supported by Flash Attention 3."
)
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
if device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
if torch.cuda.is_available():
logger.warning_once(
"You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU"
@@ -2648,8 +2585,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
"or initialising the model on CPU and then moving it to GPU."
)
elif (
check_device_map
and device_map is not None
device_map is not None
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
@@ -2657,21 +2593,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
"You are attempting to use Flash Attention 3 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
)
if not hard_check_only:
config._attn_implementation = "flash_attention_3"
return config
return True
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
def _sdpa_can_dispatch(self, hard_check_only: bool = False) -> bool:
"""
Checks the availability of SDPA for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module.
"""
if hard_check_only:
if not cls._supports_sdpa:
if not self._supports_sdpa:
raise ValueError(
f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
)
@@ -2680,45 +2613,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
"PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1."
)
if not is_torch_sdpa_available() or not cls._supports_sdpa:
return config
if (
torch.version.hip is not None
and torch.cuda.device_count() > 1
and version.parse(torch.__version__) < version.parse("2.4.1")
):
logger.warning_once(
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
)
torch.backends.cuda.enable_flash_sdp(False)
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config
# This means we have `hard_check_only=False` and fallback to eager if SDPA isn't supported
_is_bettertransformer = getattr(self, "use_bettertransformer", False)
if not is_torch_sdpa_available() or not self._supports_sdpa or _is_bettertransformer:
return False
if not hard_check_only:
config._attn_implementation = "sdpa"
return config
return True
@classmethod
def _check_and_enable_flex_attn(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
def _flex_attn_can_dispatch(self) -> bool:
"""
Checks the availability of Flex Attention for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module.
"""
if hard_check_only:
if not cls._supports_flex_attn:
raise ValueError(
f"{cls.__name__} does not support an attention implementation through torch's flex_attention."
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809."
" If you believe this error is a bug, please open an issue in Transformers GitHub repository"
' and load your model with the argument `attn_implementation="eager"` meanwhile.'
' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
)
if not is_torch_flex_attn_available():
raise ImportError(
"PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
)
if not self._supports_flex_attn:
raise ValueError(
f"{self.__class__.__name__} does not support an attention implementation through torch's flex_attention."
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809."
" If you believe this error is a bug, please open an issue in Transformers GitHub repository"
' and load your model with the argument `attn_implementation="eager"` meanwhile.'
' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
)
if not is_torch_flex_attn_available():
raise ImportError(
"PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
)
if not is_torch_flex_attn_available() or not cls._supports_flex_attn:
return config
if not hard_check_only:
config._attn_implementation = "flex_attention"
return config
# If no error raise by this point, we can return `True`
return True
def enable_input_require_grads(self):
"""
@@ -4803,13 +4735,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
if not getattr(config, "_attn_implementation_autoset", False):
config = cls._autoset_attn_implementation(
config,
torch_dtype=torch_dtype,
device_map=device_map,
)
with ContextManagers(model_init_context):
# Let's make sure we don't run the init function of buffer modules
model = cls(config, *model_args, **model_kwargs)