From 9f0402bc4db85f24626b5e6cb8b766244e2abc48 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 26 May 2025 12:11:54 +0200 Subject: [PATCH] Fix all import errors based on older torch versions (#38370) * Update masking_utils.py * fix * fix * fix * Update masking_utils.py * Update executorch.py * fix --- src/transformers/masking_utils.py | 33 +++++++++++++++++------------- src/transformers/modeling_utils.py | 2 +- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 8829c0711a..c42a1917ce 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -25,11 +25,16 @@ from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_o if is_torch_flex_attn_available(): - from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex from torch.nn.attention.flex_attention import BlockMask, create_block_mask - +else: + # Register a fake type to avoid crashing for annotations and `isinstance` checks + BlockMask = torch.Tensor _is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True) +_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True) + +if _is_torch_greater_or_equal_than_2_6: + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex def and_masks(*mask_functions: list[Callable]) -> Callable: @@ -415,14 +420,14 @@ def sdpa_mask_older_torch( # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213 - if allow_torch_fix: + if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix: causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True) return causal_mask # We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions # (especially mask_function indexing a tensor, such as the padding mask function) -sdpa_mask = sdpa_mask_recent_torch if is_torch_flex_attn_available() else sdpa_mask_older_torch +sdpa_mask = sdpa_mask_recent_torch if _is_torch_greater_or_equal_than_2_6 else sdpa_mask_older_torch def eager_mask( @@ -522,7 +527,7 @@ def flex_attention_mask( mask_function: Callable = causal_mask_function, attention_mask: Optional[torch.Tensor] = None, **kwargs, -) -> "BlockMask": +) -> BlockMask: """ Create a 4D block mask which is a compressed representation of the full 4D block causal mask. BlockMask is essential for performant computation of flex attention. See: https://pytorch.org/blog/flexattention/ @@ -652,7 +657,7 @@ def create_causal_mask( past_key_values: Optional[Cache], or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, -) -> Optional[Union[torch.Tensor, "BlockMask"]]: +) -> Optional[Union[torch.Tensor, BlockMask]]: """ Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values` has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align @@ -700,12 +705,12 @@ def create_causal_mask( # Allow slight deviations from causal mask if or_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_5: + if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_causal_skip = False if and_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_5: + if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_causal_skip = False @@ -733,7 +738,7 @@ def create_sliding_window_causal_mask( past_key_values: Optional[Cache], or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, -) -> Optional[Union[torch.Tensor, "BlockMask"]]: +) -> Optional[Union[torch.Tensor, BlockMask]]: """ Create a sliding window causal mask based on the attention implementation used (stored in the config). This type of attention pattern was mostly democratized by Mistral. If `past_key_values` has an HybridCache structure, this @@ -786,12 +791,12 @@ def create_sliding_window_causal_mask( # Allow slight deviations from sliding causal mask if or_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_5: + if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_causal_skip = False if and_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_5: + if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_causal_skip = False @@ -820,7 +825,7 @@ def create_chunked_causal_mask( past_key_values: Optional[Cache], or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, -) -> Optional[Union[torch.Tensor, "BlockMask"]]: +) -> Optional[Union[torch.Tensor, BlockMask]]: """ Create a chunked attention causal mask based on the attention implementation used (stored in the config). This type of attention pattern was mostly democratized by Llama4. If `past_key_values` has an HybridCache structure, this @@ -880,12 +885,12 @@ def create_chunked_causal_mask( # Allow slight deviations from chunked causal mask if or_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_5: + if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_causal_skip = False if and_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_5: + if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_causal_skip = False diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 58d422b37b..f426d10ccd 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2078,7 +2078,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi if plan := getattr(module, "_tp_plan", None): self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) - if self._tp_plan is not None and is_torch_greater_or_equal("2.3"): + if self._tp_plan is not None and is_torch_greater_or_equal("2.5"): for _, v in self._tp_plan.items(): if v not in ALL_PARALLEL_STYLES: raise ValueError(