From b11050d6a288e47438c2f8986bfa57aa1d5c364a Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Wed, 19 Mar 2025 22:15:52 +0800 Subject: [PATCH] enable OffloadedCache on XPU from PyTorch 2.7 (#36654) * fix "Cannot copy out of meta tensor; no data!" issue for BartForConditionalGeneration model * follow Marc's suggestion to use _tie_weights to fix Signed-off-by: Yao, Matrix * enable OffloadedCache on XPU since PyTorch 2.7 Signed-off-by: Yao, Matrix * fix style Signed-off-by: Yao, Matrix * don't change bart Signed-off-by: root * make code more concise per review comments Signed-off-by: N * fix review comments Signed-off-by: root * Revert "fix review comments" This reverts commit acf1484b86c7cc58b2dee69e7008c0eeb4c97b1b. * fix review comments Signed-off-by: root * fix style Signed-off-by: root --------- Signed-off-by: Yao, Matrix Signed-off-by: root Signed-off-by: N Co-authored-by: root Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/cache_utils.py | 23 +++++++++++------ src/transformers/utils/import_utils.py | 4 ++- tests/utils/test_cache_utils.py | 35 +++++++++++++++++++------- 3 files changed, 44 insertions(+), 18 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 11c25b2827..558bcfb2e2 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -9,7 +9,7 @@ import torch from packaging import version from .configuration_utils import PretrainedConfig -from .utils import is_hqq_available, is_optimum_quanto_available, logging +from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging if is_hqq_available(): @@ -537,10 +537,10 @@ class DynamicCache(Cache): class OffloadedCache(DynamicCache): """ - A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory. + A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory. Useful for generating from models with very long context. - In addition to the default CUDA stream, where all forward() computations happen, + In addition to the default accelerator stream, where all forward() computations happen, this class uses another stream, the prefetch stream, which it creates itself. Since scheduling of operations on separate streams happens independently, this class uses the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. @@ -549,17 +549,21 @@ class OffloadedCache(DynamicCache): """ def __init__(self) -> None: - if not torch.cuda.is_available(): - raise RuntimeError("OffloadedCache can only be used with a GPU") + if not (torch.cuda.is_available() or (is_torch_greater_or_equal("2.7") and torch.xpu.is_available())): + raise RuntimeError( + "OffloadedCache can only be used with a GPU" + (" or XPU" if is_torch_greater_or_equal("2.7") else "") + ) + super().__init__() self.original_device = [] - self.prefetch_stream = torch.cuda.Stream() + self.prefetch_stream = None + self.prefetch_stream = torch.Stream() if is_torch_greater_or_equal("2.7") else torch.cuda.Stream() self.beam_idx = None # used to delay beam search operations def prefetch_layer(self, layer_idx: int): "Starts prefetching the next layer cache" if layer_idx < len(self): - with torch.cuda.stream(self.prefetch_stream): + with self.prefetch_stream if is_torch_greater_or_equal("2.7") else torch.cuda.stream(self.prefetch_stream): # Prefetch next layer tensors to GPU device = self.original_device[layer_idx] self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) @@ -577,7 +581,10 @@ class OffloadedCache(DynamicCache): "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." if layer_idx < len(self): # Evict the previous layer if necessary - torch.cuda.current_stream().synchronize() + if is_torch_greater_or_equal("2.7"): + torch.accelerator.current_stream().synchronize() + else: + torch.cuda.current_stream().synchronize() self.evict_previous_layer(layer_idx) # Load current layer cache to its original device if not already there original_device = self.original_device[layer_idx] diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 4f2724f1c8..ad4e685b24 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1062,7 +1062,9 @@ def is_torch_greater_or_equal(library_version: str): if not _is_package_available("torch"): return False - return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version) + return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse( + library_version + ) def is_torchdistx_available(): diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index fc7617e649..efe4e6af5c 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -27,6 +27,7 @@ from transformers.testing_utils import ( require_non_xpu, require_read_token, require_torch, + require_torch_accelerator, require_torch_gpu, require_torch_multi_gpu, slow, @@ -48,7 +49,7 @@ if is_torch_available(): StaticCache, convert_and_export_with_cache, ) - from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3 + from transformers.utils import is_torch_greater_or_equal @require_torch @@ -179,7 +180,7 @@ class CacheTest(unittest.TestCase): """ Tests that static cache works with `torch.export()` """ - if not is_torch_greater_or_equal_than_2_3: + if not is_torch_greater_or_equal("2.3"): self.skipTest(reason="This test requires torch >= 2.3 to run.") set_seed(0) @@ -230,7 +231,7 @@ class CacheTest(unittest.TestCase): self.assertEqual(n_static_value_caches, model.config.num_hidden_layers) -@require_torch_gpu +@require_torch_accelerator @slow class CacheIntegrationTest(unittest.TestCase): def test_dynamic_cache_hard(self): @@ -542,13 +543,17 @@ class CacheIntegrationTest(unittest.TestCase): def test_static_cache_beam_search(self): pass - @require_torch_gpu + @require_torch_accelerator def test_offloaded_cache_equivalent_to_dynamic_cache(self): """Tests that OffloadedCache produces the same result as the default DynamicCache""" model_name = "microsoft/Phi-3-mini-4k-instruct" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16) device = model.device + + if not is_torch_greater_or_equal("2.7") and device.type == "xpu": + self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.") + input_text = "Fun fact:" inputs = tokenizer(input_text, return_tensors="pt").to(device) common = { @@ -566,13 +571,17 @@ class CacheIntegrationTest(unittest.TestCase): for original_output, offloaded_output in zip(original_outputs, offloaded_outputs): assert torch.all(original_output == offloaded_output).item() - @require_torch_gpu + @require_torch_accelerator def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self): """Tests that OffloadedCache uses less memory than the default DynamicCache""" model_name = "microsoft/Phi-3-mini-4k-instruct" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16) device = model.device + + if not is_torch_greater_or_equal("2.7") and device.type == "xpu": + self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.") + input_text = "Fun fact:" inputs = tokenizer(input_text, return_tensors="pt").to(device) common = { @@ -585,12 +594,20 @@ class CacheIntegrationTest(unittest.TestCase): } original = GenerationConfig(**common) offloaded = GenerationConfig(cache_implementation="offloaded", **common) - torch.cuda.reset_peak_memory_stats(device) + + torch_accelerator_module = None + if device.type == "cuda": + torch_accelerator_module = torch.cuda + elif device.type == "xpu": + torch_accelerator_module = torch.xpu + + torch_accelerator_module.reset_peak_memory_stats(device) model.generate(generation_config=original, **inputs) - original_peak_memory = torch.cuda.max_memory_allocated(device) - torch.cuda.reset_peak_memory_stats(device) + original_peak_memory = torch_accelerator_module.max_memory_allocated(device) + torch_accelerator_module.reset_peak_memory_stats(device) model.generate(generation_config=offloaded, **inputs) - offloaded_peak_memory = torch.cuda.max_memory_allocated(device) + offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device) + print(f"original_peak_memory: {original_peak_memory}, offloaded_peak_memory: {offloaded_peak_memory}") assert offloaded_peak_memory < original_peak_memory @require_torch_gpu