[Utils] torch version checks optionally accept dev versions (#36847)
This commit is contained in:
@@ -611,21 +611,29 @@ class OffloadedCache(DynamicCache):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
if not (torch.cuda.is_available() or (is_torch_greater_or_equal("2.7") and torch.xpu.is_available())):
|
if not (
|
||||||
|
torch.cuda.is_available()
|
||||||
|
or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available())
|
||||||
|
):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"OffloadedCache can only be used with a GPU" + (" or XPU" if is_torch_greater_or_equal("2.7") else "")
|
"OffloadedCache can only be used with a GPU"
|
||||||
|
+ (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "")
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.original_device = []
|
self.original_device = []
|
||||||
self.prefetch_stream = None
|
self.prefetch_stream = None
|
||||||
self.prefetch_stream = torch.Stream() if is_torch_greater_or_equal("2.7") else torch.cuda.Stream()
|
self.prefetch_stream = (
|
||||||
|
torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream()
|
||||||
|
)
|
||||||
self.beam_idx = None # used to delay beam search operations
|
self.beam_idx = None # used to delay beam search operations
|
||||||
|
|
||||||
def prefetch_layer(self, layer_idx: int):
|
def prefetch_layer(self, layer_idx: int):
|
||||||
"Starts prefetching the next layer cache"
|
"Starts prefetching the next layer cache"
|
||||||
if layer_idx < len(self):
|
if layer_idx < len(self):
|
||||||
with self.prefetch_stream if is_torch_greater_or_equal("2.7") else torch.cuda.stream(self.prefetch_stream):
|
with self.prefetch_stream if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.stream(
|
||||||
|
self.prefetch_stream
|
||||||
|
):
|
||||||
# Prefetch next layer tensors to GPU
|
# Prefetch next layer tensors to GPU
|
||||||
device = self.original_device[layer_idx]
|
device = self.original_device[layer_idx]
|
||||||
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
|
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
|
||||||
@@ -643,7 +651,7 @@ class OffloadedCache(DynamicCache):
|
|||||||
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
|
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
|
||||||
if layer_idx < len(self):
|
if layer_idx < len(self):
|
||||||
# Evict the previous layer if necessary
|
# Evict the previous layer if necessary
|
||||||
if is_torch_greater_or_equal("2.7"):
|
if is_torch_greater_or_equal("2.7", accept_dev=True):
|
||||||
torch.accelerator.current_stream().synchronize()
|
torch.accelerator.current_stream().synchronize()
|
||||||
else:
|
else:
|
||||||
torch.cuda.current_stream().synchronize()
|
torch.cuda.current_stream().synchronize()
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from functools import lru_cache, wraps
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
|
||||||
from safetensors.torch import storage_ptr, storage_size
|
from safetensors.torch import storage_ptr, storage_size
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
@@ -29,18 +28,16 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
|
is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
|
||||||
|
is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
|
||||||
is_torch_greater_or_equal_than_2_6 = parsed_torch_version_base >= version.parse("2.6")
|
is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
|
||||||
is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")
|
is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)
|
||||||
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
|
is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
|
||||||
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
|
|
||||||
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
|
|
||||||
|
|
||||||
# For backwards compatibility (e.g. some remote codes on Hub using those variables).
|
# For backwards compatibility (e.g. some remote codes on Hub using those variables).
|
||||||
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
|
is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True)
|
||||||
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
|
is_torch_greater_or_equal_than_1_13 = is_torch_greater_or_equal("1.13", accept_dev=True)
|
||||||
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
|
is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_dev=True)
|
||||||
|
|
||||||
# Cache this result as it's a C FFI call which can be pretty time-consuming
|
# Cache this result as it's a C FFI call which can be pretty time-consuming
|
||||||
_torch_distributed_available = torch.distributed.is_available()
|
_torch_distributed_available = torch.distributed.is_available()
|
||||||
|
|||||||
@@ -1070,13 +1070,21 @@ def is_flash_attn_greater_or_equal(library_version: str):
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def is_torch_greater_or_equal(library_version: str):
|
def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False):
|
||||||
|
"""
|
||||||
|
Accepts a library version and returns True if the current version of the library is greater than or equal to the
|
||||||
|
given version. If `accept_dev` is True, it will also accept development versions (e.g. 2.7.0.dev20250320 matches
|
||||||
|
2.7.0).
|
||||||
|
"""
|
||||||
if not _is_package_available("torch"):
|
if not _is_package_available("torch"):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if accept_dev:
|
||||||
return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse(
|
return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse(
|
||||||
library_version
|
library_version
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)
|
||||||
|
|
||||||
|
|
||||||
def is_torchdistx_available():
|
def is_torchdistx_available():
|
||||||
|
|||||||
@@ -605,7 +605,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
|||||||
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
|
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
|
||||||
device = model.device
|
device = model.device
|
||||||
|
|
||||||
if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
|
if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
|
||||||
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
|
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
|
||||||
|
|
||||||
input_text = "Fun fact:"
|
input_text = "Fun fact:"
|
||||||
@@ -633,7 +633,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
|||||||
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
|
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
|
||||||
device = model.device
|
device = model.device
|
||||||
|
|
||||||
if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
|
if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
|
||||||
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
|
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
|
||||||
|
|
||||||
input_text = "Fun fact:"
|
input_text = "Fun fact:"
|
||||||
|
|||||||
Reference in New Issue
Block a user