[Utils] torch version checks optionally accept dev versions (#36847)

This commit is contained in:
Joao Gante
2025-03-25 10:58:58 +00:00
committed by GitHub
parent 80b4c5dcc9
commit bc1c90a755
4 changed files with 35 additions and 22 deletions

View File

@@ -611,21 +611,29 @@ class OffloadedCache(DynamicCache):
""" """
def __init__(self) -> None: def __init__(self) -> None:
if not (torch.cuda.is_available() or (is_torch_greater_or_equal("2.7") and torch.xpu.is_available())): if not (
torch.cuda.is_available()
or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available())
):
raise RuntimeError( raise RuntimeError(
"OffloadedCache can only be used with a GPU" + (" or XPU" if is_torch_greater_or_equal("2.7") else "") "OffloadedCache can only be used with a GPU"
+ (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "")
) )
super().__init__() super().__init__()
self.original_device = [] self.original_device = []
self.prefetch_stream = None self.prefetch_stream = None
self.prefetch_stream = torch.Stream() if is_torch_greater_or_equal("2.7") else torch.cuda.Stream() self.prefetch_stream = (
torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream()
)
self.beam_idx = None # used to delay beam search operations self.beam_idx = None # used to delay beam search operations
def prefetch_layer(self, layer_idx: int): def prefetch_layer(self, layer_idx: int):
"Starts prefetching the next layer cache" "Starts prefetching the next layer cache"
if layer_idx < len(self): if layer_idx < len(self):
with self.prefetch_stream if is_torch_greater_or_equal("2.7") else torch.cuda.stream(self.prefetch_stream): with self.prefetch_stream if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.stream(
self.prefetch_stream
):
# Prefetch next layer tensors to GPU # Prefetch next layer tensors to GPU
device = self.original_device[layer_idx] device = self.original_device[layer_idx]
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
@@ -643,7 +651,7 @@ class OffloadedCache(DynamicCache):
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
if layer_idx < len(self): if layer_idx < len(self):
# Evict the previous layer if necessary # Evict the previous layer if necessary
if is_torch_greater_or_equal("2.7"): if is_torch_greater_or_equal("2.7", accept_dev=True):
torch.accelerator.current_stream().synchronize() torch.accelerator.current_stream().synchronize()
else: else:
torch.cuda.current_stream().synchronize() torch.cuda.current_stream().synchronize()

View File

@@ -18,7 +18,6 @@ from functools import lru_cache, wraps
from typing import Callable from typing import Callable
import torch import torch
from packaging import version
from safetensors.torch import storage_ptr, storage_size from safetensors.torch import storage_ptr, storage_size
from torch import nn from torch import nn
@@ -29,18 +28,16 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
is_torch_greater_or_equal_than_2_6 = parsed_torch_version_base >= version.parse("2.6") is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4") is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3") is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
# For backwards compatibility (e.g. some remote codes on Hub using those variables). # For backwards compatibility (e.g. some remote codes on Hub using those variables).
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0") is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True)
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13") is_torch_greater_or_equal_than_1_13 = is_torch_greater_or_equal("1.13", accept_dev=True)
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12") is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_dev=True)
# Cache this result as it's a C FFI call which can be pretty time-consuming # Cache this result as it's a C FFI call which can be pretty time-consuming
_torch_distributed_available = torch.distributed.is_available() _torch_distributed_available = torch.distributed.is_available()

View File

@@ -1070,13 +1070,21 @@ def is_flash_attn_greater_or_equal(library_version: str):
@lru_cache() @lru_cache()
def is_torch_greater_or_equal(library_version: str): def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False):
"""
Accepts a torch version string and returns True if the installed version of torch is greater than or equal to the
given version. If `accept_dev` is True, only the base version of the installed torch is compared, so development
versions are also accepted (e.g. 2.7.0.dev20250320 matches 2.7.0).
"""
if not _is_package_available("torch"): if not _is_package_available("torch"):
return False return False
if accept_dev:
return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse( return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse(
library_version library_version
) )
else:
return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)
def is_torchdistx_available(): def is_torchdistx_available():

View File

@@ -605,7 +605,7 @@ class CacheIntegrationTest(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16) model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
device = model.device device = model.device
if not is_torch_greater_or_equal("2.7") and device.type == "xpu": if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.") self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
input_text = "Fun fact:" input_text = "Fun fact:"
@@ -633,7 +633,7 @@ class CacheIntegrationTest(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16) model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
device = model.device device = model.device
if not is_torch_greater_or_equal("2.7") and device.type == "xpu": if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.") self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
input_text = "Fun fact:" input_text = "Fun fact:"