[tests] enable bnb tests on xpu (#36233)
* fix failed test
* fix device
* fix more device cases
* add more cases
* fix empty cache
* Update test_4bit.py

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -32,6 +32,7 @@ from transformers import (
|
||||
from transformers.models.opt.modeling_opt import OPTAttention
|
||||
from transformers.testing_utils import (
|
||||
apply_skip_if_not_implemented,
|
||||
backend_empty_cache,
|
||||
is_bitsandbytes_available,
|
||||
is_torch_available,
|
||||
require_accelerate,
|
||||
@@ -136,7 +137,7 @@ class Bnb4BitTest(Base4bitTest):
|
||||
del self.model_4bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_quantization_num_parameters(self):
|
||||
r"""
|
||||
@@ -224,7 +225,7 @@ class Bnb4BitTest(Base4bitTest):
|
||||
"""
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = self.model_4bit.generate(
|
||||
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
|
||||
input_ids=encoded_input["input_ids"].to(self.model_4bit.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
@@ -242,7 +243,7 @@ class Bnb4BitTest(Base4bitTest):
|
||||
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_4bit_from_config.generate(
|
||||
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
|
||||
input_ids=encoded_input["input_ids"].to(model_4bit_from_config.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
@@ -261,7 +262,7 @@ class Bnb4BitTest(Base4bitTest):
|
||||
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_4bit.generate(
|
||||
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
|
||||
input_ids=encoded_input["input_ids"].to(model_4bit.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
@@ -277,10 +278,10 @@ class Bnb4BitTest(Base4bitTest):
|
||||
self.assertEqual(self.model_4bit.device.type, "cpu")
|
||||
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if torch_device in ["cuda", "xpu"]:
|
||||
# Move back to CUDA device
|
||||
self.model_4bit.to("cuda")
|
||||
self.assertEqual(self.model_4bit.device.type, "cuda")
|
||||
self.model_4bit.to(torch_device)
|
||||
self.assertEqual(self.model_4bit.device.type, torch_device)
|
||||
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
|
||||
|
||||
def test_device_and_dtype_assignment(self):
|
||||
@@ -323,11 +324,13 @@ class Bnb4BitTest(Base4bitTest):
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
|
||||
self.model_fp16 = self.model_fp16.to(torch.float32)
|
||||
_ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10)
|
||||
_ = self.model_fp16.generate(
|
||||
input_ids=encoded_input["input_ids"].to(self.model_fp16.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if torch_device in ["cuda", "xpu"]:
|
||||
# Check that this does not throw an error
|
||||
_ = self.model_fp16.cuda()
|
||||
_ = self.model_fp16.to(torch_device)
|
||||
|
||||
# Check this does not throw an error
|
||||
_ = self.model_fp16.to("cpu")
|
||||
@@ -617,7 +620,7 @@ class BaseSerializationTest(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
|
||||
r"""
|
||||
|
||||
@@ -274,7 +274,7 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
"""
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = self.model_8bit.generate(
|
||||
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
|
||||
input_ids=encoded_input["input_ids"].to(self.model_8bit.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
@@ -292,7 +292,7 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_8bit_from_config.generate(
|
||||
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
|
||||
input_ids=encoded_input["input_ids"].to(model_8bit_from_config.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
@@ -311,7 +311,7 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_8bit.generate(
|
||||
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
|
||||
input_ids=encoded_input["input_ids"].to(model_8bit.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
@@ -362,7 +362,9 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
|
||||
self.model_fp16 = self.model_fp16.to(torch.float32)
|
||||
_ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10)
|
||||
_ = self.model_fp16.generate(
|
||||
input_ids=encoded_input["input_ids"].to(self.model_fp16.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
# Check this does not throw an error
|
||||
_ = self.model_fp16.to("cpu")
|
||||
@@ -402,7 +404,7 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
# generate
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_from_saved.generate(
|
||||
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
|
||||
input_ids=encoded_input["input_ids"].to(model_from_saved.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
@@ -429,7 +431,7 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
# generate
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_from_saved.generate(
|
||||
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
|
||||
input_ids=encoded_input["input_ids"].to(model_from_saved.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
Reference in New Issue
Block a user