enable OffloadedCache on XPU from PyTorch 2.7 (#36654)

* fix "Cannot copy out of meta tensor; no data!" issue for BartForConditionalGeneration model

* follow Marc's suggestion to use _tie_weights to fix

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* enable OffloadedCache on XPU since PyTorch 2.7

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* don't change bart

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

* make code more concise per review comments

Signed-off-by: N <matrix.yao@intel.com>

* fix review comments

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

* Revert "fix review comments"

This reverts commit acf1484b86c7cc58b2dee69e7008c0eeb4c97b1b.

* fix review comments

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

* fix style

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

---------

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>
Signed-off-by: N <matrix.yao@intel.com>
Co-authored-by: root <root@a4bf01945cfe.jf.intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Yao Matrix
2025-03-19 22:15:52 +08:00
committed by GitHub
parent e8d960329e
commit b11050d6a2
3 changed files with 44 additions and 18 deletions

View File

@@ -27,6 +27,7 @@ from transformers.testing_utils import (
require_non_xpu,
require_read_token,
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_multi_gpu,
slow,
@@ -48,7 +49,7 @@ if is_torch_available():
StaticCache,
convert_and_export_with_cache,
)
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
from transformers.utils import is_torch_greater_or_equal
@require_torch
@@ -179,7 +180,7 @@ class CacheTest(unittest.TestCase):
"""
Tests that static cache works with `torch.export()`
"""
if not is_torch_greater_or_equal_than_2_3:
if not is_torch_greater_or_equal("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
set_seed(0)
@@ -230,7 +231,7 @@ class CacheTest(unittest.TestCase):
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
@require_torch_gpu
@require_torch_accelerator
@slow
class CacheIntegrationTest(unittest.TestCase):
def test_dynamic_cache_hard(self):
@@ -542,13 +543,17 @@ class CacheIntegrationTest(unittest.TestCase):
def test_static_cache_beam_search(self):
pass
@require_torch_gpu
@require_torch_accelerator
def test_offloaded_cache_equivalent_to_dynamic_cache(self):
"""Tests that OffloadedCache produces the same result as the default DynamicCache"""
model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
device = model.device
if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
input_text = "Fun fact:"
inputs = tokenizer(input_text, return_tensors="pt").to(device)
common = {
@@ -566,13 +571,17 @@ class CacheIntegrationTest(unittest.TestCase):
for original_output, offloaded_output in zip(original_outputs, offloaded_outputs):
assert torch.all(original_output == offloaded_output).item()
@require_torch_gpu
@require_torch_accelerator
def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
"""Tests that OffloadedCache uses less memory than the default DynamicCache"""
model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
device = model.device
if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
input_text = "Fun fact:"
inputs = tokenizer(input_text, return_tensors="pt").to(device)
common = {
@@ -585,12 +594,20 @@ class CacheIntegrationTest(unittest.TestCase):
}
original = GenerationConfig(**common)
offloaded = GenerationConfig(cache_implementation="offloaded", **common)
torch.cuda.reset_peak_memory_stats(device)
torch_accelerator_module = None
if device.type == "cuda":
torch_accelerator_module = torch.cuda
elif device.type == "xpu":
torch_accelerator_module = torch.xpu
torch_accelerator_module.reset_peak_memory_stats(device)
model.generate(generation_config=original, **inputs)
original_peak_memory = torch.cuda.max_memory_allocated(device)
torch.cuda.reset_peak_memory_stats(device)
original_peak_memory = torch_accelerator_module.max_memory_allocated(device)
torch_accelerator_module.reset_peak_memory_stats(device)
model.generate(generation_config=offloaded, **inputs)
offloaded_peak_memory = torch.cuda.max_memory_allocated(device)
offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device)
print(f"original_peak_memory: {original_peak_memory}, offloaded_peak_memory: {offloaded_peak_memory}")
assert offloaded_peak_memory < original_peak_memory
@require_torch_gpu