Delete bad rebasing functions (#39672)

* remove outdated stuff

* remove comment

* use register

* remove finally clause (to allow further check if fallback to sdpa)

* general exception

* add wrapper

* revert check

* typo
This commit is contained in:
Cyril Vallez
2025-07-25 18:28:09 +02:00
committed by GitHub
parent a91653561e
commit ddb0546d14

View File

@@ -2094,8 +2094,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
_supports_attention_backend = False _supports_attention_backend = False
_can_record_outputs = None _can_record_outputs = None
# This attribute sets the default parameter to be
@property @property
@torch._dynamo.allow_in_graph @torch._dynamo.allow_in_graph
def can_record_outputs(self) -> dict[str, OutputRecorder]: def can_record_outputs(self) -> dict[str, OutputRecorder]:
@@ -2359,101 +2357,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
return model return model
@classmethod
def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]:
"""
Checks that the requested attention implementation exists and tries to get the kernel from hub
if `attn_implementation` matches hf kernels pattern.
"""
if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation):
if not is_kernels_available():
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
# Extract repo_id and kernel_name from the string
repo_id, kernel_name = attn_implementation.split(":")
kernel_name = kernel_name.strip()
repo_id = repo_id.strip()
try:
kernel = get_kernel(repo_id)
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
except FileNotFoundError as e:
logger.warning(
f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead."
)
attn_implementation = None # try to dispatch SDPA and fallback eager if not available
except AttributeError:
raise ValueError(
"the kernel function name or class specified in the attn_implementation argument is not valid. \
Please check the documentation for the correct format, \
and check that the kernel exports the class and the function correctly."
)
if (
not isinstance(attn_implementation, dict)
and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys()
):
message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False):
message += (
', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
)
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
if cls._supports_flex_attn:
message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
raise ValueError(message + ".")
return attn_implementation
def set_attention_implementation(self, attn_implementation: Union[str, dict]):
"""
Checks and dispatches to the requested attention implementation.
"""
requested_attn_implementation = self._check_attn_implementation(attn_implementation)
# Composite models consisting of several PretrainedModels can specify attention implementation as a dict where
# keys are sub-config names. But most people will specify one `str` which means that should dispatch it for all sub-models.
# See https://github.com/huggingface/transformers/pull/32238
for key in self.config.sub_configs.keys():
sub_config = getattr(self.config, key)
curr_attn_implementation = (
requested_attn_implementation
if not isinstance(requested_attn_implementation, dict)
else requested_attn_implementation.get(key, None)
)
# For models with backbone sub-config might be not initialized. Set the requested att
# if the config hasn't got any attn pre-set and the requested attn in not `None` (i.e not the default attn)
if (
sub_config is not None
and sub_config._attn_implementation_internal is None
and curr_attn_implementation is not None
):
sub_config._attn_implementation_internal = curr_attn_implementation
if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch():
self.config._attn_implementation = "flash_attention_3"
if requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch():
self.config._attn_implementation = "flash_attention_2"
elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch():
self.config._attn_implementation = "flex_attention"
elif (
requested_attn_implementation in [None, "sdpa"]
and not is_torch_xla_available()
and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None)
):
self.config._attn_implementation = "sdpa"
elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
self.config._attn_implementation = requested_attn_implementation
elif isinstance(requested_attn_implementation, dict):
self.config._attn_implementation = requested_attn_implementation.get("", None)
else:
self.config._attn_implementation = "eager"
self.config._attn_implementation_autoset = True
@classmethod @classmethod
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype: def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
""" """
@@ -2800,23 +2703,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
try: try:
kernel = get_kernel(repo_id) kernel = get_kernel(repo_id)
if hasattr(kernel, "flash_attn_varlen_func"): if hasattr(kernel, "flash_attn_varlen_func"):
ALL_ATTENTION_FUNCTIONS._global_mapping[repo_id] = partial( kernel_function = partial(flash_attention_forward, implementation=kernel)
flash_attention_forward, implementation=kernel
)
elif kernel_name is not None: elif kernel_name is not None:
ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name) kernel_function = getattr(kernel, kernel_name)
ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[ # Register it
"flash_attention_2" ALL_ATTENTION_FUNCTIONS.register(repo_id, kernel_function)
] ALL_MASK_ATTENTION_FUNCTIONS.register(repo_id, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
applicable_attn_implementation = repo_id applicable_attn_implementation = repo_id
except FileNotFoundError as e: except Exception as e:
logger.warning_once( logger.warning_once(
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using " f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
"default attention implementation instead (sdpa if available, eager otherwise)." "default attention implementation instead (sdpa if available, eager otherwise)."
) )
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
finally:
return applicable_attn_implementation
if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys(): if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
message = ( message = (
f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are ' f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '