From ddb0546d145c2f944d94444ec8327571908c280b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 25 Jul 2025 18:28:09 +0200 Subject: [PATCH] Delete bad rebasing functions (#39672) * remove outdated stuff * remove comment * use register * remove finally clause (to allow further check if fallback to sdpa) * general exception * add wrapper * revert check * typo --- src/transformers/modeling_utils.py | 113 ++--------------------------- 1 file changed, 6 insertions(+), 107 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5c4226ad2c..2c47965099 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2094,8 +2094,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH _supports_attention_backend = False _can_record_outputs = None - # This attribute sets the default parameter to be - @property @torch._dynamo.allow_in_graph def can_record_outputs(self) -> dict[str, OutputRecorder]: @@ -2359,101 +2357,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH return model - @classmethod - def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]: - """ - Checks that the requested attention implementation exists and tries to get the kernel from hub - if `attn_implementation` matches hf kernels pattern. - """ - if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation): - if not is_kernels_available(): - raise ValueError("kernels is not installed. Please install it with `pip install kernels`.") - - # Extract repo_id and kernel_name from the string - repo_id, kernel_name = attn_implementation.split(":") - kernel_name = kernel_name.strip() - repo_id = repo_id.strip() - - try: - kernel = get_kernel(repo_id) - ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name)) - attn_implementation = f"kernel_{repo_id.replace('/', '_')}" - except FileNotFoundError as e: - logger.warning( - f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead." - ) - attn_implementation = None # try to dispatch SDPA and fallback eager if not available - except AttributeError: - raise ValueError( - "the kernel function name or class specified in the attn_implementation argument is not valid. \ - Please check the documentation for the correct format, \ - and check that the kernel exports the class and the function correctly." - ) - if ( - not isinstance(attn_implementation, dict) - and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys() - ): - message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' - # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases - if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False): - message += ( - ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)' - ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' - ) - if cls._supports_sdpa: - message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' - if cls._supports_flex_attn: - message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' - raise ValueError(message + ".") - - return attn_implementation - - def set_attention_implementation(self, attn_implementation: Union[str, dict]): - """ - Checks and dispatches to the requested attention implementation. - """ - requested_attn_implementation = self._check_attn_implementation(attn_implementation) - - # Composite models consisting of several PretrainedModels can specify attention implementation as a dict where - # keys are sub-config names. But most people will specify one `str` which means that should dispatch it for all sub-models. - # See https://github.com/huggingface/transformers/pull/32238 - for key in self.config.sub_configs.keys(): - sub_config = getattr(self.config, key) - curr_attn_implementation = ( - requested_attn_implementation - if not isinstance(requested_attn_implementation, dict) - else requested_attn_implementation.get(key, None) - ) - # For models with backbone sub-config might be not initialized. Set the requested att - # if the config hasn't got any attn pre-set and the requested attn in not `None` (i.e not the default attn) - if ( - sub_config is not None - and sub_config._attn_implementation_internal is None - and curr_attn_implementation is not None - ): - sub_config._attn_implementation_internal = curr_attn_implementation - - if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch(): - self.config._attn_implementation = "flash_attention_3" - if requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch(): - self.config._attn_implementation = "flash_attention_2" - elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch(): - self.config._attn_implementation = "flex_attention" - elif ( - requested_attn_implementation in [None, "sdpa"] - and not is_torch_xla_available() - and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None) - ): - self.config._attn_implementation = "sdpa" - elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys(): - self.config._attn_implementation = requested_attn_implementation - elif isinstance(requested_attn_implementation, dict): - self.config._attn_implementation = requested_attn_implementation.get("", None) - else: - self.config._attn_implementation = "eager" - - self.config._attn_implementation_autoset = True - @classmethod def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype: """ @@ -2800,23 +2703,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH try: kernel = get_kernel(repo_id) if hasattr(kernel, "flash_attn_varlen_func"): - ALL_ATTENTION_FUNCTIONS._global_mapping[repo_id] = partial( - flash_attention_forward, implementation=kernel - ) + kernel_function = partial(flash_attention_forward, implementation=kernel) elif kernel_name is not None: - ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name) - ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[ - "flash_attention_2" - ] + kernel_function = getattr(kernel, kernel_name) + # Register it + ALL_ATTENTION_FUNCTIONS.register(repo_id, kernel_function) + ALL_MASK_ATTENTION_FUNCTIONS.register(repo_id, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]) applicable_attn_implementation = repo_id - except FileNotFoundError as e: + except Exception as e: logger.warning_once( f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using " "default attention implementation instead (sdpa if available, eager otherwise)." ) applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case - finally: - return applicable_attn_implementation if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys(): message = ( f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '