[Utils] torch version checks optionally accept dev versions (#36847)

This commit is contained in:
Joao Gante
2025-03-25 10:58:58 +00:00
committed by GitHub
parent 80b4c5dcc9
commit bc1c90a755
4 changed files with 35 additions and 22 deletions

View File

@@ -611,21 +611,29 @@ class OffloadedCache(DynamicCache):
"""
def __init__(self) -> None:
if not (torch.cuda.is_available() or (is_torch_greater_or_equal("2.7") and torch.xpu.is_available())):
if not (
torch.cuda.is_available()
or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available())
):
raise RuntimeError(
"OffloadedCache can only be used with a GPU" + (" or XPU" if is_torch_greater_or_equal("2.7") else "")
"OffloadedCache can only be used with a GPU"
+ (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "")
)
super().__init__()
self.original_device = []
self.prefetch_stream = None
self.prefetch_stream = torch.Stream() if is_torch_greater_or_equal("2.7") else torch.cuda.Stream()
self.prefetch_stream = (
torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream()
)
self.beam_idx = None # used to delay beam search operations
def prefetch_layer(self, layer_idx: int):
"Starts prefetching the next layer cache"
if layer_idx < len(self):
with self.prefetch_stream if is_torch_greater_or_equal("2.7") else torch.cuda.stream(self.prefetch_stream):
with self.prefetch_stream if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.stream(
self.prefetch_stream
):
# Prefetch next layer tensors to GPU
device = self.original_device[layer_idx]
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
@@ -643,7 +651,7 @@ class OffloadedCache(DynamicCache):
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
if layer_idx < len(self):
# Evict the previous layer if necessary
if is_torch_greater_or_equal("2.7"):
if is_torch_greater_or_equal("2.7", accept_dev=True):
torch.accelerator.current_stream().synchronize()
else:
torch.cuda.current_stream().synchronize()

View File

@@ -18,7 +18,6 @@ from functools import lru_cache, wraps
from typing import Callable
import torch
from packaging import version
from safetensors.torch import storage_ptr, storage_size
from torch import nn
@@ -29,18 +28,16 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
logger = logging.get_logger(__name__)
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_6 = parsed_torch_version_base >= version.parse("2.6")
is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)
is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
# For backwards compatibility (e.g. some remote codes on Hub using those variables).
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True)
is_torch_greater_or_equal_than_1_13 = is_torch_greater_or_equal("1.13", accept_dev=True)
is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_dev=True)
# Cache this result has it's a C FFI call which can be pretty time-consuming
_torch_distributed_available = torch.distributed.is_available()

View File

@@ -1070,13 +1070,21 @@ def is_flash_attn_greater_or_equal(library_version: str):
@lru_cache()
def is_torch_greater_or_equal(library_version: str):
def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False):
"""
Accepts a library version and returns True if the current version of the library is greater than or equal to the
given version. If `accept_dev` is True, it will also accept development versions (e.g. 2.7.0.dev20250320 matches
2.7.0).
"""
if not _is_package_available("torch"):
return False
return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse(
library_version
)
if accept_dev:
return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse(
library_version
)
else:
return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)
def is_torchdistx_available():

View File

@@ -605,7 +605,7 @@ class CacheIntegrationTest(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
device = model.device
if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
input_text = "Fun fact:"
@@ -633,7 +633,7 @@ class CacheIntegrationTest(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
device = model.device
if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
input_text = "Fun fact:"