diff --git a/docs/source/en/attention_interface.md b/docs/source/en/attention_interface.md index d78e21413e..034686ad2c 100644 --- a/docs/source/en/attention_interface.md +++ b/docs/source/en/attention_interface.md @@ -60,11 +60,11 @@ You will see it prints "I just entered the attention computation" as many times ## Dynamically switching attention function -You could dynamically change the model's attention function as well, by overriding the `config._attn_implementation` field: +You could dynamically change the model's attention function as well: ```python # Back to use original sdpa implementation -model.config._attn_implementation = "sdpa" +model.set_attn_implementation("sdpa") model(torch.ones(1, 5, dtype=int)) ``` diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 32a3b57956..243622d895 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -323,9 +323,8 @@ class PretrainedConfig(PushToHubMixin): self._name_or_path = str(kwargs.pop("name_or_path", "")) self._commit_hash = kwargs.pop("_commit_hash", None) - # Attention implementation to use, if relevant. - self._attn_implementation_internal = kwargs.pop("attn_implementation", None) - self._attn_implementation_autoset = False + # Attention implementation to use, if relevant (it sets it recursively on sub-configs) + self._attn_implementation = kwargs.pop("attn_implementation", None) # Drop the transformers version info self.transformers_version = kwargs.pop("transformers_version", None) @@ -370,8 +369,11 @@ class PretrainedConfig(PushToHubMixin): return self._output_attentions @output_attentions.setter - def output_attentions(self, value): - if value is True and self._attn_implementation != "eager": + def output_attentions(self, value: bool): + # If we set `output_attentions` explictily before the attn implementation, dispatch eager + if value and self._attn_implementation is None: + self._attn_implementation = "eager" + if value and self._attn_implementation != "eager": raise ValueError( "The `output_attentions` attribute is not supported when using the `attn_implementation` set to " f"{self._attn_implementation}. Please set it to 'eager' instead." @@ -402,19 +404,23 @@ class PretrainedConfig(PushToHubMixin): @property def _attn_implementation(self): - # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.) - if hasattr(self, "_attn_implementation_internal"): - if self._attn_implementation_internal is None: - # `config.attn_implementation` should never be None, for backward compatibility. - return "eager" - else: - return self._attn_implementation_internal - else: - return "eager" + return self._attn_implementation_internal @_attn_implementation.setter - def _attn_implementation(self, value): - self._attn_implementation_internal = value + def _attn_implementation(self, value: Optional[Union[str, dict]]): + """We set it recursively on the sub-configs as well""" + # Set if for current config + attn_implementation = value if not isinstance(value, dict) else value.get("", self._attn_implementation) + self._attn_implementation_internal = attn_implementation + + # Set it recursively on the subconfigs + for subconfig_key in self.sub_configs: + subconfig = getattr(self, subconfig_key, None) + if subconfig is not None: + sub_implementation = ( + value if not isinstance(value, dict) else value.get(subconfig_key, subconfig._attn_implementation) + ) + subconfig._attn_implementation = sub_implementation def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ @@ -1053,8 +1059,6 @@ class PretrainedConfig(PushToHubMixin): del d["_commit_hash"] if "_attn_implementation_internal" in d: del d["_attn_implementation_internal"] - if "_attn_implementation_autoset" in d: - del d["_attn_implementation_autoset"] # Do not serialize `base_model_tp_plan` for now if "base_model_tp_plan" in d: del d["base_model_tp_plan"] diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a97ba8511d..9d2f70a497 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -24,6 +24,7 @@ import json import os import re import shutil +import sys import tempfile import warnings from abc import abstractmethod @@ -2094,12 +2095,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi ) self.config = config - # The `hasattr` here is used as some Transformers tests for some reason do not call - # PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) - if hasattr(config, "_attn_implementation_internal") and not getattr( - config, "_attn_implementation_autoset", False - ): - self.set_attention_implementation(self.config._attn_implementation_internal) + # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid + # setting it recursively) + self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation( + self.config._attn_implementation, is_init_check=True + ) # for initialization of the loss loss_type = self.__class__.__name__ @@ -2244,14 +2244,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi if torch_dtype is not None: dtype_orig = cls._set_default_torch_dtype(torch_dtype) - config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. - - if config._attn_implementation_internal is not None: - # In this case, the config has been created with the attn_implementation set by the user, which we should respect. - attn_implementation = config._attn_implementation_internal - else: - attn_implementation = None - config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation) + # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs) + if "attn_implementation" in kwargs: + config._attn_implementation = kwargs.pop("attn_implementation") if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called: logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") @@ -2272,101 +2267,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi return model - @classmethod - def _check_attn_implementation(cls, attn_implementation: Union[dict, str]) -> Union[dict, str]: - """ - Checks that the requested attention implementation exists and tries to get the kernel from hub - if `attn_implementation` matches hf kernels pattern. - """ - if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation): - if not is_kernels_available(): - raise ValueError("kernels is not installed. Please install it with `pip install kernels`.") - - # Extract repo_id and kernel_name from the string - repo_id, kernel_name = attn_implementation.split(":") - kernel_name = kernel_name.strip() - repo_id = repo_id.strip() - - try: - kernel = get_kernel(repo_id) - ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name)) - attn_implementation = f"kernel_{repo_id.replace('/', '_')}" - except FileNotFoundError as e: - logger.warning( - f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead." - ) - attn_implementation = None # try to dispatch SDPA and fallback eager if not available - except AttributeError: - raise ValueError( - "the kernel function name or class specified in the attn_implementation argument is not valid. \ - Please check the documentation for the correct format, \ - and check that the kernel exports the class and the function correctly." - ) - if ( - not isinstance(attn_implementation, dict) - and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys() - ): - message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' - # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases - if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False): - message += ( - ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)' - ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' - ) - if cls._supports_sdpa: - message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' - if cls._supports_flex_attn: - message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' - raise ValueError(message + ".") - - return attn_implementation - - def set_attention_implementation(self, attn_implementation: Union[dict, str]): - """ - Checks and dispatches to the requested attention implementation. - """ - requested_attn_implementation = self._check_attn_implementation(attn_implementation) - - # Composite models consisting of several PretrainedModels can specify attention implementation as a dict where - # keys are sub-config names. But most people will specify one `str` which means that should dispatch it for all sub-models. - # See https://github.com/huggingface/transformers/pull/32238 - for key in self.config.sub_configs.keys(): - sub_config = getattr(self.config, key) - curr_attn_implementation = ( - requested_attn_implementation - if not isinstance(requested_attn_implementation, dict) - else requested_attn_implementation.get(key, None) - ) - # For models with backbone sub-config might be not initialized. Set the requested att - # if the config hasn't got any attn pre-set and the requested attn in not `None` (i.e not the default attn) - if ( - sub_config is not None - and sub_config._attn_implementation_internal is None - and curr_attn_implementation is not None - ): - sub_config._attn_implementation_internal = curr_attn_implementation - - if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch(): - self.config._attn_implementation = "flash_attention_3" - if requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch(): - self.config._attn_implementation = "flash_attention_2" - elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch(): - self.config._attn_implementation = "flex_attention" - elif ( - requested_attn_implementation in [None, "sdpa"] - and not is_torch_xla_available() - and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None) - ): - self.config._attn_implementation = "sdpa" - elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys(): - self.config._attn_implementation = requested_attn_implementation - elif isinstance(requested_attn_implementation, dict): - self.config._attn_implementation = requested_attn_implementation.get("", None) - else: - self.config._attn_implementation = "eager" - - self.config._attn_implementation_autoset = True - @classmethod def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype: """ @@ -2439,15 +2339,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # Otherwise, can't generate return False - def _flash_attn_2_can_dispatch(self) -> bool: + def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool: """ - Checks the availability of Flash Attention 2 and compatibility with the current model. + Check the availability of Flash Attention 2 for a given model. - If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. + Args: + is_init_check (`bool`, *optional*): + Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are + fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses + BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early + before instantiating the full models if we know that the model does not support the requested attention. """ - # Config always has `torch_dtype` but we need the next line for `no_super_init()` tests - torch_dtype = self.config.torch_dtype if hasattr(self.config, "torch_dtype") else torch.get_default_dtype() - device_map = self.hf_device_map if hasattr(self, "hf_device_map") else None + torch_dtype = self.config.torch_dtype # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases if not (self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False)): @@ -2486,68 +2389,62 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi else: raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") - _is_bettertransformer = getattr(self, "use_bettertransformer", False) - if _is_bettertransformer: - raise ValueError( - "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" - ) - if torch_dtype is None: logger.warning_once( - "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour" + "You are attempting to use Flash Attention 2 without specifying a torch dtype. This might lead to unexpected behaviour" ) elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]: logger.warning_once( - "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but" + "Flash Attention 2 only supports torch.float16 and torch.bfloat16 dtypes, but" f" the current dype in {self.__class__.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`' ) - # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, - # or the model may be initialized under the context manager `with torch.device("cuda"):`. - if device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]: - if torch.cuda.is_available(): - logger.warning_once( - "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU" - " after initializing it on CPU with `model.to('cuda')`." - ) - elif is_torch_mlu_available(): - logger.warning_once( - "You are attempting to use Flash Attention 2.0 with a model not initialized on MLU. Make sure to move the model to MLU" - " after initializing it on CPU with `model.to('mlu')`." - ) - else: + # With the early check, the parameters are not yet initalized correctly + if not is_init_check: + if getattr(self, "use_bettertransformer", False): raise ValueError( - "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. " - "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " - "or initialising the model on CPU and then moving it to GPU." + "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" ) - elif ( - device_map is not None - and isinstance(device_map, dict) - and ("cpu" in device_map.values() or "disk" in device_map.values()) - ): - raise ValueError( - "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to " - "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." - ) + + param_devices = list({param.device for param in self.parameters()}) + if len(param_devices) == 1 and param_devices[0].type == "cpu": + if torch.cuda.is_available(): + logger.warning_once( + "You are attempting to use Flash Attention 2 with a model not initialized on GPU. Make sure to move the model to GPU" + " after initializing it on CPU with `model.to('cuda')`." + ) + elif is_torch_mlu_available(): + logger.warning_once( + "You are attempting to use Flash Attention 2 with a model not initialized on MLU. Make sure to move the model to MLU" + " after initializing it on CPU with `model.to('mlu')`." + ) + else: + raise ValueError( + "You are attempting to use Flash Attention 2 with a model not initialized on GPU and with no GPU available. " + "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " + "or initialising the model on CPU and then moving it to GPU." + ) # If no error raise by this point, we can return `True` return True - def _flash_attn_3_can_dispatch(self) -> bool: + def _flash_attn_3_can_dispatch(self, is_init_check: bool = False) -> bool: """ - Checks the availability of Flash Attention 3 and compatibility with the current model. + Check the availability of Flash Attention 3 for a given model. - If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module. + Args: + is_init_check (`bool`, *optional*): + Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are + fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses + BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early + before instantiating the full models if we know that the model does not support the requested attention. """ - # Config always has `torch_dtype` but we need the next line for `no_super_init()` tests - torch_dtype = self.config.torch_dtype if hasattr(self.config, "torch_dtype") else torch.get_default_dtype() - device_map = self.hf_device_map if hasattr(self, "hf_device_map") else None + torch_dtype = self.config.torch_dtype if not self._supports_flash_attn: raise ValueError( - f"{self.__class__.__name__} does not support Flash Attention 3.0 yet. Please request to add support where" + f"{self.__class__.__name__} does not support Flash Attention 3 yet. Please request to add support where" f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new" " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" ) @@ -2591,48 +2488,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi f"Model has attention_dropout={self.config.attention_dropout}, which is not supported by Flash Attention 3." ) - # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, - # or the model may be initialized under the context manager `with torch.device("cuda"):`. - if device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]: - if torch.cuda.is_available(): - logger.warning_once( - "You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU" - " after initializing it on CPU with `model.to('cuda')`." - ) - else: - raise ValueError( - "You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. " - "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " - "or initialising the model on CPU and then moving it to GPU." - ) - elif ( - device_map is not None - and isinstance(device_map, dict) - and ("cpu" in device_map.values() or "disk" in device_map.values()) - ): - raise ValueError( - "You are attempting to use Flash Attention 3 with a model dispatched on CPU or disk. This is not supported. Please make sure to " - "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." - ) + # With the early check, the parameters are not yet initalized correctly + if not is_init_check: + param_devices = list({param.device for param in self.parameters()}) + if len(param_devices) == 1 and param_devices[0].type == "cpu": + if torch.cuda.is_available(): + logger.warning_once( + "You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU" + " after initializing it on CPU with `model.to('cuda')`." + ) + else: + raise ValueError( + "You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. " + "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " + "or initialising the model on CPU and then moving it to GPU." + ) + return True - def _sdpa_can_dispatch(self, hard_check_only: bool = False) -> bool: + def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool: """ - Checks the availability of SDPA for a given model. + Check the availability of SDPA for a given model. - If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module. + Args: + is_init_check (`bool`, *optional*): + Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are + fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses + BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early + before instantiating the full models if we know that the model does not support the requested attention. """ - if hard_check_only: - if not self._supports_sdpa: - raise ValueError( - f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet." - " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe" - ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' - ) - if not is_torch_sdpa_available(): - raise ImportError( - "PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1." - ) + if not self._supports_sdpa: + raise ValueError( + f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet." + " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe" + ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' + ) + if not is_torch_sdpa_available(): + raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1.") if ( torch.version.hip is not None @@ -2644,18 +2536,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi ) torch.backends.cuda.enable_flash_sdp(False) - # This means we have `hard_check_only=False` and fallback to eager if SDPA isn't supported - _is_bettertransformer = getattr(self, "use_bettertransformer", False) - if not is_torch_sdpa_available() or not self._supports_sdpa or _is_bettertransformer: - return False + if not is_init_check: + if getattr(self, "use_bettertransformer", False): + raise ValueError( + "SDPA and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" + ) return True - def _flex_attn_can_dispatch(self) -> bool: + def _flex_attn_can_dispatch(self, is_init_check: bool = False) -> bool: """ - Checks the availability of Flex Attention for a given model. + Check the availability of Flex Attention for a given model. - If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module. + Args: + is_init_check (`bool`, *optional*): + Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are + fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses + BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early + before instantiating the full models if we know that the model does not support the requested attention. """ if not self._supports_flex_attn: raise ValueError( @@ -2670,9 +2568,190 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi "PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0." ) + if not is_init_check: + if getattr(self, "use_bettertransformer", False): + raise ValueError( + "FlexAttention and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" + ) + # If no error raise by this point, we can return `True` return True + def _check_and_adjust_attn_implementation( + self, attn_implementation: Optional[str], is_init_check: bool = False + ) -> str: + """ + Check that the `attn_implementation` exists and is supported by the models, and try to get the kernel from hub if + it matches hf kernels pattern. + + Args: + attn_implementation (`str` or `None`): + The attention implementation to check for existence/validity. + is_init_check (`bool`, *optional*): + Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are + fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses + BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early + before instantiating the full models if we know that the model does not support the requested attention. + + Returns: + `str`: The final attention implementation to use, including potential fallbacks from sdpa to eager, or from + None to sdpa (to potentially eager). + """ + applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation + if re.match(r"^[^/:]+/[^/:]+:[^/:]+$", applicable_attn_implementation): + if not is_kernels_available(): + raise ValueError("kernels is not installed. Please install it with `pip install kernels`.") + + # Extract repo_id and kernel_name from the string + repo_id, kernel_name = applicable_attn_implementation.split(":") + kernel_name = kernel_name.strip() + repo_id = repo_id.strip() + + try: + kernel = get_kernel(repo_id) + ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name)) + applicable_attn_implementation = f"kernel_{repo_id.replace('/', '_')}" + except FileNotFoundError as e: + logger.warning_once( + f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using " + "default attention implementation instead (sdpa if available, eager otherwise)." + ) + applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case + except AttributeError: + raise ValueError( + "the kernel function name or class specified in the attn_implementation argument is not valid. Please check " + "the documentation for the correct format, and check that the kernel exports the class and the function correctly." + ) + if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys(): + message = ( + f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are ' + '`attn_implementation="eager"` (manual attention implementation)' + ) + # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases + if self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False): + message += ( + ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)' + ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + ) + if self._supports_sdpa: + message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' + if self._supports_flex_attn: + message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' + raise ValueError(message + ".") + + # Perform relevant checks + if applicable_attn_implementation == "flash_attention_2": + self._flash_attn_2_can_dispatch(is_init_check) + elif applicable_attn_implementation == "flash_attention_3": + self._flash_attn_3_can_dispatch(is_init_check) + elif applicable_attn_implementation == "flex_attention": + self._flex_attn_can_dispatch(is_init_check) + elif applicable_attn_implementation == "sdpa": + # Sdpa is the default, so we try it and fallback to eager otherwise when not possible + try: + self._sdpa_can_dispatch(is_init_check) + except (ValueError, ImportError) as e: + # In this case, sdpa was requested explicitly, but we can't use it, so let's raise + if attn_implementation == "sdpa": + raise e + applicable_attn_implementation = "eager" + + return applicable_attn_implementation + + @classmethod + def _can_set_attn_implementation(cls) -> bool: + """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on + opening the file, but avoids maintaining yet another property flag. + """ + class_file = sys.modules[cls.__module__].__file__ + with open(class_file, "r") as f: + code = f.read() + # heuristic -> if we find those patterns, the model uses the correct interface + return ( + "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code + ) + + def set_attn_implementation(self, attn_implementation: Union[str, dict]): + """ + Set the requested `attn_implementation` for this model. + + Args: + attn_implementation (`str` or `dict`): + The attention implementation to set for this model. It can be either a `str`, in which case it will be + dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each + submodel will dispatch the corresponding value. + """ + requested_implementation = ( + attn_implementation + if not isinstance(attn_implementation, dict) + else attn_implementation.get("", self.config._attn_implementation) + ) + + # At this point, the model was already instantiated, so instead of crashing on bad value, let's simply + # warn the user that the requested value is not working + if requested_implementation != self.config._attn_implementation: + # In this case, raise + if not self._can_set_attn_implementation(): + logger.warning( + f"{self.__class__.__name__} does not support setting its attention implementation dynamically, because it " + "does not follow the functional approach based on AttentionInterface " + "(see https://huggingface.co/docs/transformers/en/attention_interface)" + ) + else: + try: + applicable_attn_implementation = self._check_and_adjust_attn_implementation( + requested_implementation, is_init_check=False + ) + # Apply the change (on the internal attr, to avoid setting it recursively) + self.config._attn_implementation_internal = applicable_attn_implementation + except (ValueError, ImportError) as e: + logger.warning( + f"Impossible to set the requested `attn_implementation`. The following error was captured: {str(e)}" + ) + + subconfigs_changed = set() + # Apply it to all submodels as well + for submodule in self.modules(): + # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model", + # e.g. ForCausalLM has a Model inside, but no need to check it again) + if ( + submodule is not self + and isinstance(submodule, PreTrainedModel) + and submodule.config.__class__ != self.config.__class__ + ): + sub_implementation = attn_implementation + if isinstance(attn_implementation, dict): + for subconfig_key in self.config.sub_configs: + # We need to check for exact object match here, with `is` + if getattr(self.config, subconfig_key) is submodule.config: + sub_implementation = attn_implementation.get( + subconfig_key, submodule.config._attn_implementation + ) + break + submodule.set_attn_implementation(sub_implementation) + subconfigs_changed.add(submodule.config.__class__) + + # We need this as some old and badly designed models use subconfigs without declaring the corresponding modules as PreTrainedModel + for subconfig_key in self.config.sub_configs: + subconfig = getattr(self.config, subconfig_key) + requested_implementation = ( + attn_implementation + if not isinstance(attn_implementation, dict) + else attn_implementation.get(subconfig_key, subconfig._attn_implementation) + ) + # This means we did not perform any check above for this particular subconfig -> set it in the dark if it is registered + if ( + subconfig.__class__ not in subconfigs_changed + and requested_implementation != subconfig._attn_implementation + and requested_implementation in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys() + ): + subconfig._attn_implementation_internal = requested_implementation + logger.warning( + f"We set the attention implementation for the sub-config `{subconfig_key}` to `{requested_implementation}` " + "without finding the associated sub-model. For this reason we could not check if the model supports it. " + "You may encounter undefined behavior." + ) + def enable_input_require_grads(self): """ Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping @@ -4601,21 +4680,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi if "gguf_file" in model_kwargs: model_kwargs.pop("gguf_file") else: - # In case one passes a config to `from_pretrained` + "attn_implementation" - # override the `_attn_implementation` attribute to `attn_implementation` of the kwargs - # Please see: https://github.com/huggingface/transformers/issues/28038 - - # Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory - # we pop attn_implementation from the kwargs but this handles the case where users - # passes manually the config to `from_pretrained`. config = copy.deepcopy(config) - - kwarg_attn_imp = kwargs.pop("attn_implementation", None) - if kwarg_attn_imp is not None: - config._attn_implementation = kwarg_attn_imp - model_kwargs = kwargs + # Because some composite configs call super().__init__ before instantiating the sub-configs, we need this call + # to correctly redispatch recursively if the kwarg is provided + if "attn_implementation" in kwargs: + config._attn_implementation = kwargs.pop("attn_implementation") + transformers_explicit_filename = getattr(config, "transformers_weights", None) if transformers_explicit_filename is not None: diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 9f44323fe4..296d8a5d80 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -1226,6 +1226,8 @@ class AltCLIPModel(AltCLIPPreTrainedModel): text_config = config.text_config vision_config = config.vision_config + # The module using it is not a PreTrainedModel subclass so we need this + vision_config._attn_implementation = config._attn_implementation self.projection_dim = config.projection_dim self.text_embed_dim = text_config.project_dim diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 1e43e754bb..4584373241 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch BART model.""" -import copy import math import warnings from typing import Callable, Optional, Union @@ -1842,7 +1841,6 @@ class BartForCausalLM(BartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): - config = copy.deepcopy(config) config.is_decoder = True config.is_encoder_decoder = False super().__init__(config) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 2220efc887..8146d7b018 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch BigBirdPegasus model.""" -import copy import math from typing import Callable, Optional, Union @@ -2923,7 +2922,6 @@ class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): - config = copy.deepcopy(config) config.is_decoder = True config.is_encoder_decoder = False super().__init__(config) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index c4bb3b6e19..3b9467e6e2 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch Blenderbot model.""" -import copy import math import os import warnings @@ -1488,7 +1487,6 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): - config = copy.deepcopy(config) config.is_decoder = True config.is_encoder_decoder = False super().__init__(config) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 74e5d0767a..7f7701d9da 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch BlenderbotSmall model.""" -import copy import math from typing import Callable, Optional, Union @@ -1447,7 +1446,6 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel, GenerationMixin _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): - config = copy.deepcopy(config) config.is_decoder = True config.is_encoder_decoder = False super().__init__(config) diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 6fcc04a940..cf134e22c6 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -1046,6 +1046,8 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel): text_config = config.text_config vision_config = config.vision_config + # The module using it is not a PreTrainedModel subclass so we need this + vision_config._attn_implementation = config._attn_implementation self.projection_dim = config.projection_dim self.text_embed_dim = text_config.hidden_size diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 46c7b1fcf2..34d19ffaf3 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -828,6 +828,10 @@ class CLIPSegModel(CLIPSegPreTrainedModel): text_config = config.text_config vision_config = config.vision_config + # The module using it is not a PreTrainedModel subclass so we need this + text_config._attn_implementation = config._attn_implementation + # The module using it is not a PreTrainedModel subclass so we need this + vision_config._attn_implementation = config._attn_implementation self.projection_dim = config.projection_dim self.text_embed_dim = text_config.hidden_size diff --git a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py index 10664a8fef..fac33b0818 100644 --- a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch GPTSANJapanese model.""" -import copy from typing import Optional, Union import torch @@ -849,7 +848,6 @@ class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel): def __init__(self, config: GPTSanJapaneseConfig): super().__init__(config) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model) - self.config = copy.deepcopy(config) self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) self.last_project = nn.Linear(config.d_model, config.d_model, bias=True) self.act = ACT2FN["swish"] diff --git a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py index 1012c9537a..aa248038f6 100755 --- a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch Speech2Text2 model.""" -import copy import math from typing import Optional, Union @@ -682,7 +681,6 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): - config = copy.deepcopy(config) config.is_decoder = True config.is_encoder_decoder = False super().__init__(config) diff --git a/src/transformers/models/dpt/configuration_dpt.py b/src/transformers/models/dpt/configuration_dpt.py index be789b9a5b..70e46f2320 100644 --- a/src/transformers/models/dpt/configuration_dpt.py +++ b/src/transformers/models/dpt/configuration_dpt.py @@ -294,7 +294,11 @@ class DPTConfig(PretrainedConfig): @property def sub_configs(self): - return {"backbone_config": type(self.backbone_config)} if self.backbone_config is not None else {} + return ( + {"backbone_config": type(self.backbone_config)} + if getattr(self, "backbone_config", None) is not None + else {} + ) __all__ = ["DPTConfig"] diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 9ac09ca6d0..ec752c2b1d 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -942,6 +942,8 @@ class IdeficsModel(IdeficsPreTrainedModel): self.image_size = config.vision_config.image_size self.vision_config = config.vision_config + # The module using it is not a PreTrainedModel subclass so we need this + self.vision_config._attn_implementation = config._attn_implementation self.vision_model = IdeficsVisionTransformer(config.vision_config) # Perceiver Resampler diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 052bbe3c1a..3bcc64db47 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1601,7 +1601,6 @@ class MarianForCausalLM(MarianPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): - config = copy.deepcopy(config) config.is_decoder = True config.is_encoder_decoder = False super().__init__(config) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 63b6ec0cb1..c02de1d4fe 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch MBART model.""" -import copy import math from typing import Callable, Optional, Union @@ -1799,7 +1798,6 @@ class MBartForCausalLM(MBartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): - config = copy.deepcopy(config) config.is_decoder = True config.is_encoder_decoder = False super().__init__(config) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index f30bc3073e..6b9be40ebf 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1882,7 +1882,7 @@ class MT5EncoderModel(MT5PreTrainedModel): super().__init__(config) self.shared = nn.Embedding(config.vocab_size, config.d_model) - encoder_config = copy.deepcopy(config) + encoder_config = config encoder_config.use_cache = False encoder_config.is_encoder_decoder = False self.encoder = MT5Stack(encoder_config, self.shared) diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index fd8f19eccc..370afff391 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch MVP model.""" -import copy import math from typing import Optional, Union @@ -1680,7 +1679,6 @@ class MvpForCausalLM(MvpPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): - config = copy.deepcopy(config) config.is_decoder = True config.is_encoder_decoder = False super().__init__(config) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 52cb126e51..07cb8835bb 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import math from typing import Callable, Optional, Union @@ -1630,7 +1629,6 @@ class PLBartForCausalLM(PLBartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): - config = copy.deepcopy(config) config.is_decoder = True config.is_encoder_decoder = False super().__init__(config) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 10eeadd766..6ddf829ad4 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1689,18 +1689,10 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo def __init__(self, config: Qwen2_5OmniThinkerConfig): super().__init__(config) - self.audio_tower = Qwen2_5OmniAudioEncoder._from_config( - config.audio_config, attn_implementation=config._attn_implementation - ) - - self.visual = Qwen2_5OmniVisionEncoder._from_config( - config.vision_config, attn_implementation=config._attn_implementation - ) - + self.audio_tower = Qwen2_5OmniAudioEncoder._from_config(config.audio_config) + self.visual = Qwen2_5OmniVisionEncoder._from_config(config.vision_config) self.vocab_size = config.text_config.vocab_size - self.model = Qwen2_5OmniThinkerTextModel._from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + self.model = Qwen2_5OmniThinkerTextModel._from_config(config.text_config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.spatial_merge_size = config.vision_config.spatial_merge_size @@ -2953,7 +2945,6 @@ class DiTAttention(nn.Module): self.heads = config.num_attention_heads self.inner_dim = config.head_dim * config.num_attention_heads self.dropout = config.dropout - self._attn_implementation = config._attn_implementation self.is_causal = False self.to_q = nn.Linear(config.hidden_size, self.inner_dim) @@ -2987,7 +2978,7 @@ class DiTAttention(nn.Module): cos, sin = position_embeddings query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) - attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attention_weights, _ = attention_interface( self, query, diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 34e61e0b26..e315a84583 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -2144,18 +2144,10 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo def __init__(self, config: Qwen2_5OmniThinkerConfig): super().__init__(config) - self.audio_tower = Qwen2_5OmniAudioEncoder._from_config( - config.audio_config, attn_implementation=config._attn_implementation - ) - - self.visual = Qwen2_5OmniVisionEncoder._from_config( - config.vision_config, attn_implementation=config._attn_implementation - ) - + self.audio_tower = Qwen2_5OmniAudioEncoder._from_config(config.audio_config) + self.visual = Qwen2_5OmniVisionEncoder._from_config(config.vision_config) self.vocab_size = config.text_config.vocab_size - self.model = Qwen2_5OmniThinkerTextModel._from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + self.model = Qwen2_5OmniThinkerTextModel._from_config(config.text_config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.spatial_merge_size = config.vision_config.spatial_merge_size @@ -3270,7 +3262,6 @@ class DiTAttention(nn.Module): self.heads = config.num_attention_heads self.inner_dim = config.head_dim * config.num_attention_heads self.dropout = config.dropout - self._attn_implementation = config._attn_implementation self.is_causal = False self.to_q = nn.Linear(config.hidden_size, self.inner_dim) @@ -3304,7 +3295,7 @@ class DiTAttention(nn.Module): cos, sin = position_embeddings query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) - attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attention_weights, _ = attention_interface( self, query, diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index 42f9609b5e..b4446d68f5 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1149,6 +1149,8 @@ class SamModel(SamPreTrainedModel): self.vision_encoder = SamVisionEncoder(config.vision_config) self.prompt_encoder = SamPromptEncoder(config) + # The module using it is not a PreTrainedModel subclass so we need this + config.mask_decoder_config._attn_implementation = config._attn_implementation self.mask_decoder = SamMaskDecoder(config.mask_decoder_config) self.post_init() diff --git a/src/transformers/models/sam_hq/modeling_sam_hq.py b/src/transformers/models/sam_hq/modeling_sam_hq.py index 042e7ab7c0..6418e71720 100644 --- a/src/transformers/models/sam_hq/modeling_sam_hq.py +++ b/src/transformers/models/sam_hq/modeling_sam_hq.py @@ -1274,6 +1274,8 @@ class SamHQModel(SamHQPreTrainedModel): self.shared_image_embedding = SamHQPositionalEmbedding(config.vision_config) self.vision_encoder = SamHQVisionEncoder(config.vision_config) self.prompt_encoder = SamHQPromptEncoder(config) + # The module using it is not a PreTrainedModel subclass so we need this + config.mask_decoder_config._attn_implementation = config._attn_implementation self.mask_decoder = SamHQMaskDecoder(config.mask_decoder_config) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 472444220c..ba5741a904 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1836,7 +1836,7 @@ class T5EncoderModel(T5PreTrainedModel): super().__init__(config) self.shared = nn.Embedding(config.vocab_size, config.d_model) - encoder_config = copy.deepcopy(config) + encoder_config = config encoder_config.use_cache = False encoder_config.is_encoder_decoder = False self.encoder = T5Stack(encoder_config, self.shared) diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index eb7ee35f49..1bc780b3d0 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch TrOCR decoder model (based on RoBERTa).""" -import copy import math from typing import Optional, Union @@ -715,7 +714,6 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin): _tied_weights_keys = ["output_projection.weight"] def __init__(self, config): - config = copy.deepcopy(config) config.is_decoder = True config.is_encoder_decoder = False super().__init__(config) diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 5a2916585d..f6c5c51a27 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -1173,6 +1173,10 @@ class XCLIPModel(XCLIPPreTrainedModel): text_config = config.text_config vision_config = config.vision_config + # The module using it is not a PreTrainedModel subclass so we need this + text_config._attn_implementation = config._attn_implementation + # The module using it is not a PreTrainedModel subclass so we need this + vision_config._attn_implementation = config._attn_implementation self.projection_dim = config.projection_dim self.text_embed_dim = text_config.hidden_size diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 4d24fc6e70..f923273c02 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -1085,6 +1085,10 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + @unittest.skip("T5 backbone deepcopies the configs, and fixing it would be more involved") + def test_internal_model_config_and_subconfig_are_same(self): + pass + class Blip2TextModelWithProjectionTester: def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True): diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index 8eeb65de77..26737a6473 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -542,6 +542,8 @@ class GPTBigCodeMQATest(unittest.TestCase): attn_pdrop=0, resid_pdrop=0, ) + # We need to set it here as it's normally set by the Model's __init__ + config._attn_implementation = "sdpa" return GPTBigCodeAttention(config) @parameterized.expand([(seed, is_train_mode) for seed in range(5) for is_train_mode in [True, False]]) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index fbb8d5f541..8539e83713 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4783,6 +4783,126 @@ class ModelTesterMixin: f"All parameters should be on meta device, but found {unique_devices}.", ) + def test_internal_model_config_and_subconfig_are_same(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + subconfig_keys = list(config.sub_configs.keys()) + for model_class in self.all_model_classes: + if len(config.sub_configs) == 0: + self.skipTest(reason="No subconfigs so the test does not make sense") + # Need to deepcopy here to avoid changing the _attn_implementation in-place + model = model_class(copy.deepcopy(config)) + + for submodule in model.modules(): + # This is a submodel + if isinstance(submodule, PreTrainedModel) and submodule.config.__class__ != model.config.__class__: + subconfig_from_model_internal = submodule.config + matching_sub_configs = [] + for subconfig_key in subconfig_keys: + # Get the subconfig from the model config + subconfig_from_model_config = getattr(model.config, subconfig_key) + if subconfig_from_model_config.__class__ == subconfig_from_model_internal.__class__: + # Since some composite models have different submodels parameterized by 2 of the same config + # class instances, we need to check against a list of matching classes, and check that at least + # 1 is the exact object (instead of checking immediately for similar object) + matching_sub_configs.append(subconfig_from_model_config) + + # Both should be exactly the same object, that is when instantiating the submodel when should + # absolutely not copy the subconfig + if len(matching_sub_configs) > 0: + self.assertTrue( + any( + subconfig_from_model_config is subconfig_from_model_internal + for subconfig_from_model_config in matching_sub_configs + ) + ) + + def test_can_set_attention_dynamically(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + if not model_class._can_set_attn_implementation(): + self.skipTest(reason="This model does not support setting its attention dynamically") + + # Need to deepcopy here to avoid changing the _attn_implementation in-place + model_config = copy.deepcopy(config) + # Set eager everywhere (it sets it recursively on subconfigs) + model_config._attn_implementation = "eager" + model = model_class(model_config) + + # sanity check to make sure everything is correctly eager + self.assertTrue(model.config._attn_implementation == "eager") + for subconfig_key in model.config.sub_configs: + self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager") + + if not all( + submodule._can_set_attn_implementation() + for submodule in model.modules() + if isinstance(submodule, PreTrainedModel) + ): + self.skipTest(reason="Parts of this model cannot set attention dynamically") + # Some old models technically should support switching, but don't have the flags active... + if not all( + submodule._supports_sdpa for submodule in model.modules() if isinstance(submodule, PreTrainedModel) + ): + self.skipTest(reason="Parts of this model don't support sdpa") + + # Now, set it to sdpa + model.set_attn_implementation("sdpa") + + # Check everything was correctly changed + self.assertTrue(model.config._attn_implementation == "sdpa") + for subconfig_key in model.config.sub_configs: + self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "sdpa") + + # Check we cannot set it to random values, and it raises a warning (but no crash) + with self.assertLogs("transformers.modeling_utils", level="WARNING") as cm: + model.set_attn_implementation("foo") + self.assertTrue( + any( + "Impossible to set the requested `attn_implementation`. The following error was captured:" + in warning + for warning in cm.output + ) + ) + + # Should still be sdpa everywhere + self.assertTrue(model.config._attn_implementation == "sdpa") + for subconfig_key in model.config.sub_configs: + self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "sdpa") + + def test_can_set_attention_dynamically_composite_model(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + if not model_class._can_set_attn_implementation(): + self.skipTest(reason="This model does not support setting its attention dynamically") + if not self._is_composite: + self.skipTest(reason="This model is not composite") + + # Need to deepcopy here to avoid changing the _attn_implementation in-place + model_config = copy.deepcopy(config) + # Set eager everywhere (it sets it recursively on subconfigs) + model_config._attn_implementation = "eager" + model = model_class(model_config) + + # sanity check to make sure everything is correctly eager + self.assertTrue(model.config._attn_implementation == "eager") + for subconfig_key in model.config.sub_configs: + self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager") + + if not all( + submodule._can_set_attn_implementation() + for submodule in model.modules() + if isinstance(submodule, PreTrainedModel) + ): + self.skipTest(reason="Parts of this model cannot set attention dynamically") + + # Now, set only top-most to sdpa (should support it if it supports the dynamic switch) + model.set_attn_implementation({"": "sdpa"}) + + # Check only top-most was correctly changed + self.assertTrue(model.config._attn_implementation == "sdpa") + for subconfig_key in model.config.sub_configs: + self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager") + global_rng = random.Random() diff --git a/tests/utils/test_configuration_utils.py b/tests/utils/test_configuration_utils.py index a34e9a5ea9..2bfd493993 100644 --- a/tests/utils/test_configuration_utils.py +++ b/tests/utils/test_configuration_utils.py @@ -194,7 +194,6 @@ class ConfigTestUtils(unittest.TestCase): "_name_or_path", "_commit_hash", "_attn_implementation_internal", - "_attn_implementation_autoset", "transformers_version", ], ) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 57d97ff214..b58629757e 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -83,12 +83,13 @@ from transformers.utils.import_utils import ( sys.path.append(str(Path(__file__).parent.parent.parent / "utils")) -from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402 +from test_module.custom_configuration import CustomConfig + if is_torch_available(): import torch from safetensors.torch import save_file as safe_save_file - from test_module.custom_modeling import CustomModel, NoSuperInitModel + from test_module.custom_modeling import CustomModel from torch import nn from transformers import ( @@ -732,36 +733,21 @@ class ModelUtilsTest(TestCasePlus): config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation) # Ensure the config was set correctly self.assertEqual(config._attn_implementation, requested_attn_implementation) - self.assertEqual(config._attn_implementation_internal, requested_attn_implementation) model = AutoModelForCausalLM.from_config(config) self.assertEqual(model.config._attn_implementation, requested_attn_implementation) config = AutoConfig.from_pretrained(TINY_MISTRAL) # When the config is not set, the default is "eager" - self.assertEqual(config._attn_implementation, "eager") - self.assertEqual(config._attn_implementation_internal, None) + self.assertEqual(config._attn_implementation, None) model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation) self.assertEqual(model.config._attn_implementation, requested_attn_implementation) # Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz") self.assertEqual(config._attn_implementation, "foo-bar-baz") - self.assertEqual(config._attn_implementation_internal, "foo-bar-baz") model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation) self.assertEqual(model.config._attn_implementation, requested_attn_implementation) - def test_no_super_init_config_and_model(self): - config = NoSuperInitConfig(attribute=32) - model = NoSuperInitModel(config) - - with tempfile.TemporaryDirectory() as tmp_dir: - model.save_pretrained(tmp_dir) - - new_model = NoSuperInitModel.from_pretrained(tmp_dir) - - for p1, p2 in zip(model.parameters(), new_model.parameters()): - self.assertTrue(torch.equal(p1, p2)) - def test_checkpoint_sharding_local_bin(self): model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") diff --git a/utils/test_module/custom_configuration.py b/utils/test_module/custom_configuration.py index 676486fc51..4bb0fe6a15 100644 --- a/utils/test_module/custom_configuration.py +++ b/utils/test_module/custom_configuration.py @@ -7,10 +7,3 @@ class CustomConfig(PretrainedConfig): def __init__(self, attribute=1, **kwargs): self.attribute = attribute super().__init__(**kwargs) - - -class NoSuperInitConfig(PretrainedConfig): - model_type = "custom" - - def __init__(self, attribute=1, **kwargs): - self.attribute = attribute diff --git a/utils/test_module/custom_modeling.py b/utils/test_module/custom_modeling.py index 4b64b4a3df..fafa7bff25 100644 --- a/utils/test_module/custom_modeling.py +++ b/utils/test_module/custom_modeling.py @@ -2,7 +2,7 @@ import torch from transformers import PreTrainedModel -from .custom_configuration import CustomConfig, NoSuperInitConfig +from .custom_configuration import CustomConfig class CustomModel(PreTrainedModel): @@ -17,17 +17,3 @@ class CustomModel(PreTrainedModel): def _init_weights(self, module): pass - - -class NoSuperInitModel(PreTrainedModel): - config_class = NoSuperInitConfig - - def __init__(self, config): - super().__init__(config) - self.linear = torch.nn.Linear(config.attribute, config.attribute) - - def forward(self, x): - return self.linear(x) - - def _init_weights(self, module): - pass