From bc1c90a755505d576220ed2a161ca42dbf3cafaf Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 25 Mar 2025 10:58:58 +0000 Subject: [PATCH] [Utils] torch version checks optionally accept dev versions (#36847) --- src/transformers/cache_utils.py | 18 +++++++++++++----- src/transformers/pytorch_utils.py | 19 ++++++++----------- src/transformers/utils/import_utils.py | 16 ++++++++++++---- tests/utils/test_cache_utils.py | 4 ++-- 4 files changed, 35 insertions(+), 22 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index b8d8d9ca97..8fd8e96396 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -611,21 +611,29 @@ class OffloadedCache(DynamicCache): """ def __init__(self) -> None: - if not (torch.cuda.is_available() or (is_torch_greater_or_equal("2.7") and torch.xpu.is_available())): + if not ( + torch.cuda.is_available() + or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) + ): raise RuntimeError( - "OffloadedCache can only be used with a GPU" + (" or XPU" if is_torch_greater_or_equal("2.7") else "") + "OffloadedCache can only be used with a GPU" + + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") ) super().__init__() self.original_device = [] self.prefetch_stream = None - self.prefetch_stream = torch.Stream() if is_torch_greater_or_equal("2.7") else torch.cuda.Stream() + self.prefetch_stream = ( + torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream() + ) self.beam_idx = None # used to delay beam search operations def prefetch_layer(self, layer_idx: int): "Starts prefetching the next layer cache" if layer_idx < len(self): - with self.prefetch_stream if is_torch_greater_or_equal("2.7") else torch.cuda.stream(self.prefetch_stream): + with self.prefetch_stream if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.stream( + self.prefetch_stream + ): # Prefetch next layer tensors to GPU device = self.original_device[layer_idx] self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) @@ -643,7 +651,7 @@ class OffloadedCache(DynamicCache): "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." if layer_idx < len(self): # Evict the previous layer if necessary - if is_torch_greater_or_equal("2.7"): + if is_torch_greater_or_equal("2.7", accept_dev=True): torch.accelerator.current_stream().synchronize() else: torch.cuda.current_stream().synchronize() diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index c899490824..4082fba798 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -18,7 +18,6 @@ from functools import lru_cache, wraps from typing import Callable import torch -from packaging import version from safetensors.torch import storage_ptr, storage_size from torch import nn @@ -29,18 +28,16 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm] logger = logging.get_logger(__name__) -parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) - -is_torch_greater_or_equal_than_2_6 = parsed_torch_version_base >= version.parse("2.6") -is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4") -is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3") -is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2") -is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1") +is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True) +is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True) +is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True) +is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True) +is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True) # For backwards compatibility (e.g. some remote codes on Hub using those variables). -is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0") -is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13") -is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12") +is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True) +is_torch_greater_or_equal_than_1_13 = is_torch_greater_or_equal("1.13", accept_dev=True) +is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_dev=True) # Cache this result has it's a C FFI call which can be pretty time-consuming _torch_distributed_available = torch.distributed.is_available() diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index b6eb2be5db..3897080516 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1070,13 +1070,21 @@ def is_flash_attn_greater_or_equal(library_version: str): @lru_cache() -def is_torch_greater_or_equal(library_version: str): +def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False): + """ + Accepts a library version and returns True if the current version of the library is greater than or equal to the + given version. If `accept_dev` is True, it will also accept development versions (e.g. 2.7.0.dev20250320 matches + 2.7.0). + """ if not _is_package_available("torch"): return False - return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse( - library_version - ) + if accept_dev: + return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse( + library_version + ) + else: + return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version) def is_torchdistx_available(): diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index fd5b720d90..816632ea53 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -605,7 +605,7 @@ class CacheIntegrationTest(unittest.TestCase): model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16) device = model.device - if not is_torch_greater_or_equal("2.7") and device.type == "xpu": + if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu": self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.") input_text = "Fun fact:" @@ -633,7 +633,7 @@ class CacheIntegrationTest(unittest.TestCase): model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16) device = model.device - if not is_torch_greater_or_equal("2.7") and device.type == "xpu": + if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu": self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.") input_text = "Fun fact:"