[tests] enable bnb tests on xpu (#36233)

* fix failed test

* fix device

* fix more device cases

* add more cases

* fix empty cache

* Update test_4bit.py

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Fanli Lin
2025-02-24 18:30:15 +08:00
committed by GitHub
parent 92c5ca9dd7
commit 4dbf17c17f
4 changed files with 27 additions and 21 deletions

View File

@@ -32,6 +32,7 @@ from transformers import (
from transformers.models.opt.modeling_opt import OPTAttention
from transformers.testing_utils import (
apply_skip_if_not_implemented,
backend_empty_cache,
is_bitsandbytes_available,
is_torch_available,
require_accelerate,
@@ -136,7 +137,7 @@ class Bnb4BitTest(Base4bitTest):
del self.model_4bit
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_quantization_num_parameters(self):
r"""
@@ -224,7 +225,7 @@ class Bnb4BitTest(Base4bitTest):
"""
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = self.model_4bit.generate(
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
input_ids=encoded_input["input_ids"].to(self.model_4bit.device), max_new_tokens=10
)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@@ -242,7 +243,7 @@ class Bnb4BitTest(Base4bitTest):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_4bit_from_config.generate(
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
input_ids=encoded_input["input_ids"].to(model_4bit_from_config.device), max_new_tokens=10
)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@@ -261,7 +262,7 @@ class Bnb4BitTest(Base4bitTest):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_4bit.generate(
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
input_ids=encoded_input["input_ids"].to(model_4bit.device), max_new_tokens=10
)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@@ -277,10 +278,10 @@ class Bnb4BitTest(Base4bitTest):
self.assertEqual(self.model_4bit.device.type, "cpu")
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
if torch.cuda.is_available():
if torch_device in ["cuda", "xpu"]:
# Move back to CUDA device
self.model_4bit.to("cuda")
self.assertEqual(self.model_4bit.device.type, "cuda")
self.model_4bit.to(torch_device)
self.assertEqual(self.model_4bit.device.type, torch_device)
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
def test_device_and_dtype_assignment(self):
@@ -323,11 +324,13 @@ class Bnb4BitTest(Base4bitTest):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
self.model_fp16 = self.model_fp16.to(torch.float32)
_ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10)
_ = self.model_fp16.generate(
input_ids=encoded_input["input_ids"].to(self.model_fp16.device), max_new_tokens=10
)
if torch.cuda.is_available():
if torch_device in ["cuda", "xpu"]:
# Check that this does not throw an error
_ = self.model_fp16.cuda()
_ = self.model_fp16.to(torch_device)
# Check this does not throw an error
_ = self.model_fp16.to("cpu")
@@ -617,7 +620,7 @@ class BaseSerializationTest(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
r"""