[refactor] set attention implementation (#38974)
* update * fix some tests * init from config, changes it in-place, add deepcopy in tests * fix modernbert * don't delete this config attr * update * style and copies * skip tests in generation * fix style * accidentally removed flash-attn-3, revert * docs * forgot about flags set to False * fix copies * address a few comments * fix copies * custom code BC
This commit is contained in:
committed by
GitHub
parent
6017f5e8ed
commit
8d6259b0b8
@@ -1961,11 +1961,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
supports_gradient_checkpointing = False
|
||||
_is_stateful = False
|
||||
|
||||
# Flash Attention 2 support
|
||||
_supports_flash_attn_2 = False
|
||||
|
||||
# Flash Attention 3 support
|
||||
_supports_flash_attn_3 = False
|
||||
# Flash Attention support
|
||||
_supports_flash_attn = False
|
||||
|
||||
# SDPA support
|
||||
_supports_sdpa = False
|
||||
@@ -2074,12 +2071,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
"`PretrainedConfig`. To create a model from a pretrained model use "
|
||||
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
if not getattr(config, "_attn_implementation_autoset", False):
|
||||
# config usually has a `torch_dtype` but we need the next line for the `no_super_init` tests
|
||||
dtype = config.torch_dtype if hasattr(config, "torch_dtype") else torch.get_default_dtype()
|
||||
config = self._autoset_attn_implementation(config, torch_dtype=dtype, check_device_map=False)
|
||||
self.config = config
|
||||
|
||||
# The `hasattr` here is used as some Transformers tests for some reason do not call
|
||||
# PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
|
||||
if hasattr(config, "_attn_implementation_internal") and not getattr(
|
||||
config, "_attn_implementation_autoset", False
|
||||
):
|
||||
self.set_attention_implementation(self.config._attn_implementation_internal)
|
||||
|
||||
# for initialization of the loss
|
||||
loss_type = self.__class__.__name__
|
||||
if loss_type not in LOSS_MAPPING:
|
||||
@@ -2226,19 +2226,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config.
|
||||
|
||||
if config._attn_implementation_internal is not None:
|
||||
# In this case, the config has been created with the attn_implementation set by the user, which we
|
||||
# should respect.
|
||||
# In this case, the config has been created with the attn_implementation set by the user, which we should respect.
|
||||
attn_implementation = config._attn_implementation_internal
|
||||
else:
|
||||
attn_implementation = None
|
||||
|
||||
config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation)
|
||||
if not getattr(config, "_attn_implementation_autoset", False):
|
||||
config = cls._autoset_attn_implementation(
|
||||
config,
|
||||
check_device_map=False,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
|
||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||
@@ -2260,81 +2252,65 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def _autoset_attn_implementation(
|
||||
cls,
|
||||
config,
|
||||
torch_dtype: Optional[torch.dtype] = None,
|
||||
device_map: Optional[Union[str, dict[str, int]]] = None,
|
||||
check_device_map: bool = True,
|
||||
):
|
||||
def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]:
|
||||
"""
|
||||
Automatically checks and dispatches to a default attention implementation. In order of priority:
|
||||
1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
|
||||
2. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
|
||||
3. The default model's implementation otherwise (`LlamaAttention` for example) .
|
||||
Checks that the requested attention implementation exists and tries to get the kernel from hub
|
||||
if `attn_implementation` matches hf kernels pattern.
|
||||
"""
|
||||
# Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user.
|
||||
# The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
|
||||
# The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
|
||||
requested_attn_implementation = None
|
||||
if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
|
||||
if isinstance(config._attn_implementation, str) and re.match(
|
||||
r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation
|
||||
):
|
||||
if not is_kernels_available():
|
||||
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
|
||||
if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation):
|
||||
if not is_kernels_available():
|
||||
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
|
||||
|
||||
# Extract repo_id and kernel_name from the string
|
||||
repo_id, kernel_name = config._attn_implementation.split(":")
|
||||
kernel_name = kernel_name.strip()
|
||||
repo_id = repo_id.strip()
|
||||
# Extract repo_id and kernel_name from the string
|
||||
repo_id, kernel_name = attn_implementation.split(":")
|
||||
kernel_name = kernel_name.strip()
|
||||
repo_id = repo_id.strip()
|
||||
|
||||
try:
|
||||
kernel = get_kernel(repo_id)
|
||||
ALL_ATTENTION_FUNCTIONS.register(
|
||||
f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name)
|
||||
)
|
||||
config._attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(
|
||||
f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead."
|
||||
)
|
||||
config._attn_implementation = "eager"
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"the kernel function name or class specified in the attn_implementation argument is not valid. \
|
||||
Please check the documentation for the correct format, \
|
||||
and check that the kernel exports the class and the function correctly."
|
||||
)
|
||||
try:
|
||||
kernel = get_kernel(repo_id)
|
||||
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
|
||||
attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(
|
||||
f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead."
|
||||
)
|
||||
attn_implementation = None # try to dispatch SDPA and fallback eager if not available
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"the kernel function name or class specified in the attn_implementation argument is not valid. \
|
||||
Please check the documentation for the correct format, \
|
||||
and check that the kernel exports the class and the function correctly."
|
||||
)
|
||||
if (
|
||||
not isinstance(attn_implementation, dict)
|
||||
and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys()
|
||||
):
|
||||
message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
|
||||
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
|
||||
if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False):
|
||||
message += (
|
||||
', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
|
||||
', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
|
||||
)
|
||||
if cls._supports_sdpa:
|
||||
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
|
||||
if cls._supports_flex_attn:
|
||||
message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
|
||||
raise ValueError(message + ".")
|
||||
|
||||
if (
|
||||
not isinstance(config._attn_implementation, dict)
|
||||
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
|
||||
):
|
||||
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
|
||||
if cls._supports_flash_attn_3:
|
||||
message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
|
||||
if cls._supports_flash_attn_2:
|
||||
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
|
||||
if cls._supports_sdpa:
|
||||
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
|
||||
if cls._supports_flex_attn:
|
||||
message += (
|
||||
', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
|
||||
)
|
||||
raise ValueError(message + ".")
|
||||
return attn_implementation
|
||||
|
||||
# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
|
||||
requested_attn_implementation = config._attn_implementation_internal
|
||||
def set_attention_implementation(self, attn_implementation: Union[str, dict]):
|
||||
"""
|
||||
Checks and dispatches to the requested attention implementation.
|
||||
"""
|
||||
requested_attn_implementation = self._check_attn_implementation(attn_implementation)
|
||||
|
||||
# Composite models consisting of several PretrainedModels have to specify attention impl as a dict
|
||||
# where keys are sub-config names. But most people will specify one `str` which means that should dispatch it
|
||||
# for all sub-models.
|
||||
# Below we check if a config is composite and manually prepare a dict of attn impl if not already passed as a dict.
|
||||
# Later each sub-module will dispatch with its own attn impl, by calling `XXXModel._from_config(config.text_config)`
|
||||
# If any of sub-modules doesn't support requested attn, an error will be raised. See https://github.com/huggingface/transformers/pull/32238
|
||||
for key in config.sub_configs.keys():
|
||||
sub_config = getattr(config, key)
|
||||
# Composite models consisting of several PretrainedModels can specify attention implementation as a dict where
|
||||
# keys are sub-config names. But most people will specify one `str` which means that should dispatch it for all sub-models.
|
||||
# See https://github.com/huggingface/transformers/pull/32238
|
||||
for key in self.config.sub_configs.keys():
|
||||
sub_config = getattr(self.config, key)
|
||||
curr_attn_implementation = (
|
||||
requested_attn_implementation
|
||||
if not isinstance(requested_attn_implementation, dict)
|
||||
@@ -2349,50 +2325,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
):
|
||||
sub_config._attn_implementation_internal = curr_attn_implementation
|
||||
|
||||
if config._attn_implementation == "flash_attention_3":
|
||||
cls._check_and_enable_flash_attn_3(
|
||||
config,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
hard_check_only=False,
|
||||
check_device_map=check_device_map,
|
||||
)
|
||||
elif config._attn_implementation == "flash_attention_2":
|
||||
cls._check_and_enable_flash_attn_2(
|
||||
config,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
hard_check_only=False,
|
||||
check_device_map=check_device_map,
|
||||
)
|
||||
elif requested_attn_implementation == "flex_attention":
|
||||
config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
|
||||
elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
|
||||
# flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
|
||||
config = cls._check_and_enable_sdpa(
|
||||
config,
|
||||
hard_check_only=requested_attn_implementation is not None,
|
||||
)
|
||||
|
||||
if (
|
||||
torch.version.hip is not None
|
||||
and config._attn_implementation == "sdpa"
|
||||
and torch.cuda.device_count() > 1
|
||||
and version.parse(torch.__version__) < version.parse("2.4.1")
|
||||
):
|
||||
logger.warning_once(
|
||||
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
|
||||
)
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch():
|
||||
self.config._attn_implementation = "flash_attention_3"
|
||||
if requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch():
|
||||
self.config._attn_implementation = "flash_attention_2"
|
||||
elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch():
|
||||
self.config._attn_implementation = "flex_attention"
|
||||
elif (
|
||||
requested_attn_implementation in [None, "sdpa"]
|
||||
and not is_torch_xla_available()
|
||||
and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None)
|
||||
):
|
||||
self.config._attn_implementation = "sdpa"
|
||||
elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
config._attn_implementation = requested_attn_implementation
|
||||
self.config._attn_implementation = requested_attn_implementation
|
||||
elif isinstance(requested_attn_implementation, dict):
|
||||
config._attn_implementation = None
|
||||
self.config._attn_implementation = requested_attn_implementation.get("", None)
|
||||
else:
|
||||
config._attn_implementation = "eager"
|
||||
self.config._attn_implementation = "eager"
|
||||
|
||||
config._attn_implementation_autoset = True
|
||||
return config
|
||||
self.config._attn_implementation_autoset = True
|
||||
|
||||
@classmethod
|
||||
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
|
||||
@@ -2466,24 +2418,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# Otherwise, can't generate
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _check_and_enable_flash_attn_2(
|
||||
cls,
|
||||
config,
|
||||
torch_dtype: Optional[torch.dtype] = None,
|
||||
device_map: Optional[Union[str, dict[str, int]]] = None,
|
||||
check_device_map: bool = True,
|
||||
hard_check_only: bool = False,
|
||||
) -> PretrainedConfig:
|
||||
def _flash_attn_2_can_dispatch(self) -> bool:
|
||||
"""
|
||||
Checks the availability of Flash Attention 2 and compatibility with the current model.
|
||||
|
||||
If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
|
||||
"""
|
||||
if not cls._supports_flash_attn_2:
|
||||
# Config always has `torch_dtype` but we need the next line for `no_super_init()` tests
|
||||
torch_dtype = self.config.torch_dtype if hasattr(self.config, "torch_dtype") else torch.get_default_dtype()
|
||||
device_map = self.hf_device_map if hasattr(self, "hf_device_map") else None
|
||||
|
||||
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
|
||||
if not (self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False)):
|
||||
raise ValueError(
|
||||
f"{cls.__name__} does not support Flash Attention 2.0 yet. Please request to add support where"
|
||||
f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
|
||||
f"{self.__class__.__name__} does not support Flash Attention 2.0 yet. Please request to add support where"
|
||||
f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new"
|
||||
" or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
|
||||
)
|
||||
|
||||
@@ -2491,39 +2440,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
|
||||
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
|
||||
|
||||
if importlib.util.find_spec("flash_attn") is None:
|
||||
# package `flash-attn` can not be installed on Ascend NPU, ignore related validation logic and early exit.
|
||||
if is_torch_npu_available():
|
||||
if not hard_check_only:
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
|
||||
logger.info("Detect using FlashAttention2 on Ascend NPU.")
|
||||
return config
|
||||
else:
|
||||
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
|
||||
|
||||
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
|
||||
if torch.version.cuda:
|
||||
if flash_attention_version < version.parse("2.1.0"):
|
||||
raise ImportError(
|
||||
f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
|
||||
)
|
||||
elif not torch.cuda.is_available():
|
||||
raise ValueError(
|
||||
f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device."
|
||||
)
|
||||
else:
|
||||
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
|
||||
elif torch.version.hip:
|
||||
if flash_attention_version < version.parse("2.0.4"):
|
||||
raise ImportError(
|
||||
f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}"
|
||||
)
|
||||
else:
|
||||
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
|
||||
|
||||
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
|
||||
# package `flash-attn` can not be installed on Ascend NPU, ignore related validation logic.
|
||||
if importlib.util.find_spec("flash_attn") is None and not is_torch_npu_available():
|
||||
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
|
||||
else:
|
||||
# Check FA2 installed version compatibility
|
||||
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
|
||||
if torch.version.cuda:
|
||||
if flash_attention_version < version.parse("2.1.0"):
|
||||
raise ImportError(
|
||||
f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
|
||||
)
|
||||
elif not torch.cuda.is_available():
|
||||
raise ValueError(
|
||||
f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device."
|
||||
)
|
||||
else:
|
||||
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
|
||||
elif torch.version.hip:
|
||||
if flash_attention_version < version.parse("2.0.4"):
|
||||
raise ImportError(
|
||||
f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Detected version {flash_attention_version}. {install_message}"
|
||||
)
|
||||
else:
|
||||
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
|
||||
|
||||
_is_bettertransformer = getattr(self, "use_bettertransformer", False)
|
||||
if _is_bettertransformer:
|
||||
raise ValueError(
|
||||
"Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
|
||||
@@ -2536,13 +2478,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
|
||||
logger.warning_once(
|
||||
"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but"
|
||||
f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
|
||||
f" the current dype in {self.__class__.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
|
||||
' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`'
|
||||
)
|
||||
|
||||
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
|
||||
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
|
||||
if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
|
||||
if device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
|
||||
if torch.cuda.is_available():
|
||||
logger.warning_once(
|
||||
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
|
||||
@@ -2560,8 +2502,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
"or initialising the model on CPU and then moving it to GPU."
|
||||
)
|
||||
elif (
|
||||
check_device_map
|
||||
and device_map is not None
|
||||
device_map is not None
|
||||
and isinstance(device_map, dict)
|
||||
and ("cpu" in device_map.values() or "disk" in device_map.values())
|
||||
):
|
||||
@@ -2569,28 +2510,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
"You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
|
||||
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
|
||||
)
|
||||
if not hard_check_only:
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _check_and_enable_flash_attn_3(
|
||||
cls,
|
||||
config,
|
||||
torch_dtype: Optional[torch.dtype] = None,
|
||||
device_map: Optional[Union[str, dict[str, int]]] = None,
|
||||
check_device_map: bool = True,
|
||||
hard_check_only: bool = False,
|
||||
) -> PretrainedConfig:
|
||||
# If no error was raised by this point, we can return `True`
|
||||
return True
|
||||
|
||||
def _flash_attn_3_can_dispatch(self) -> bool:
|
||||
"""
|
||||
Checks the availability of Flash Attention 3 and compatibility with the current model.
|
||||
|
||||
If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module.
|
||||
"""
|
||||
if not cls._supports_flash_attn_3:
|
||||
# Config always has `torch_dtype` but we need the next line for `no_super_init()` tests
|
||||
torch_dtype = self.config.torch_dtype if hasattr(self.config, "torch_dtype") else torch.get_default_dtype()
|
||||
device_map = self.hf_device_map if hasattr(self, "hf_device_map") else None
|
||||
|
||||
if not self._supports_flash_attn:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
|
||||
f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
|
||||
f"{self.__class__.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
|
||||
f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new"
|
||||
" or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
|
||||
)
|
||||
|
||||
@@ -2620,22 +2557,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
|
||||
logger.warning_once(
|
||||
"Flash Attention 3 only supports torch.float16 and torch.bfloat16 dtypes, but"
|
||||
f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
|
||||
f" the current dype in {self.__class__.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
|
||||
' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`'
|
||||
)
|
||||
|
||||
if getattr(config, "alibi", False) or getattr(config, "use_alibi", False):
|
||||
if getattr(self.config, "alibi", False) or getattr(self.config, "use_alibi", False):
|
||||
raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.")
|
||||
|
||||
# Check for attention dropout, which is incompatible with FA3
|
||||
if hasattr(config, "attention_dropout") and config.attention_dropout > 0:
|
||||
if hasattr(self.config, "attention_dropout") and self.config.attention_dropout > 0:
|
||||
raise ValueError(
|
||||
f"Model has attention_dropout={config.attention_dropout}, which is not supported by Flash Attention 3."
|
||||
f"Model has attention_dropout={self.config.attention_dropout}, which is not supported by Flash Attention 3."
|
||||
)
|
||||
|
||||
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
|
||||
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
|
||||
if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
|
||||
if device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
|
||||
if torch.cuda.is_available():
|
||||
logger.warning_once(
|
||||
"You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU"
|
||||
@@ -2648,8 +2585,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
"or initialising the model on CPU and then moving it to GPU."
|
||||
)
|
||||
elif (
|
||||
check_device_map
|
||||
and device_map is not None
|
||||
device_map is not None
|
||||
and isinstance(device_map, dict)
|
||||
and ("cpu" in device_map.values() or "disk" in device_map.values())
|
||||
):
|
||||
@@ -2657,21 +2593,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
"You are attempting to use Flash Attention 3 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
|
||||
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
|
||||
)
|
||||
if not hard_check_only:
|
||||
config._attn_implementation = "flash_attention_3"
|
||||
return config
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
|
||||
def _sdpa_can_dispatch(self, hard_check_only: bool = False) -> bool:
|
||||
"""
|
||||
Checks the availability of SDPA for a given model.
|
||||
|
||||
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module.
|
||||
"""
|
||||
if hard_check_only:
|
||||
if not cls._supports_sdpa:
|
||||
if not self._supports_sdpa:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
|
||||
f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
|
||||
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
|
||||
' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
|
||||
)
|
||||
@@ -2680,45 +2613,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
"PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1."
|
||||
)
|
||||
|
||||
if not is_torch_sdpa_available() or not cls._supports_sdpa:
|
||||
return config
|
||||
if (
|
||||
torch.version.hip is not None
|
||||
and torch.cuda.device_count() > 1
|
||||
and version.parse(torch.__version__) < version.parse("2.4.1")
|
||||
):
|
||||
logger.warning_once(
|
||||
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
|
||||
)
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
|
||||
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
|
||||
if _is_bettertransformer:
|
||||
return config
|
||||
# This means we have `hard_check_only=False` and fallback to eager if SDPA isn't supported
|
||||
_is_bettertransformer = getattr(self, "use_bettertransformer", False)
|
||||
if not is_torch_sdpa_available() or not self._supports_sdpa or _is_bettertransformer:
|
||||
return False
|
||||
|
||||
if not hard_check_only:
|
||||
config._attn_implementation = "sdpa"
|
||||
return config
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _check_and_enable_flex_attn(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
|
||||
def _flex_attn_can_dispatch(self) -> bool:
|
||||
"""
|
||||
Checks the availability of Flex Attention for a given model.
|
||||
|
||||
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module.
|
||||
"""
|
||||
if hard_check_only:
|
||||
if not cls._supports_flex_attn:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} does not support an attention implementation through torch's flex_attention."
|
||||
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809."
|
||||
" If you believe this error is a bug, please open an issue in Transformers GitHub repository"
|
||||
' and load your model with the argument `attn_implementation="eager"` meanwhile.'
|
||||
' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
|
||||
)
|
||||
if not is_torch_flex_attn_available():
|
||||
raise ImportError(
|
||||
"PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
|
||||
)
|
||||
if not self._supports_flex_attn:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support an attention implementation through torch's flex_attention."
|
||||
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809."
|
||||
" If you believe this error is a bug, please open an issue in Transformers GitHub repository"
|
||||
' and load your model with the argument `attn_implementation="eager"` meanwhile.'
|
||||
' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
|
||||
)
|
||||
if not is_torch_flex_attn_available():
|
||||
raise ImportError(
|
||||
"PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
|
||||
)
|
||||
|
||||
if not is_torch_flex_attn_available() or not cls._supports_flex_attn:
|
||||
return config
|
||||
|
||||
if not hard_check_only:
|
||||
config._attn_implementation = "flex_attention"
|
||||
|
||||
return config
|
||||
# If no error was raised by this point, we can return `True`
|
||||
return True
|
||||
|
||||
def enable_input_require_grads(self):
|
||||
"""
|
||||
@@ -4803,13 +4735,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
|
||||
|
||||
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
|
||||
if not getattr(config, "_attn_implementation_autoset", False):
|
||||
config = cls._autoset_attn_implementation(
|
||||
config,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
)
|
||||
|
||||
with ContextManagers(model_init_context):
|
||||
# Let's make sure we don't run the init function of buffer modules
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user