[tests] enable bitsandbytes (bnb) tests on XPU (#36233)

* fix failed test

* fix device

* fix more device cases

* add more cases

* fix empty cache

* Update test_4bit.py

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Fanli Lin
2025-02-24 18:30:15 +08:00
committed by GitHub
parent 92c5ca9dd7
commit 4dbf17c17f
4 changed files with 27 additions and 21 deletions

View File

@@ -274,7 +274,7 @@ class MixedInt8Test(BaseMixedInt8Test):
"""
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = self.model_8bit.generate(
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
input_ids=encoded_input["input_ids"].to(self.model_8bit.device), max_new_tokens=10
)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@@ -292,7 +292,7 @@ class MixedInt8Test(BaseMixedInt8Test):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_8bit_from_config.generate(
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
input_ids=encoded_input["input_ids"].to(model_8bit_from_config.device), max_new_tokens=10
)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@@ -311,7 +311,7 @@ class MixedInt8Test(BaseMixedInt8Test):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_8bit.generate(
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
input_ids=encoded_input["input_ids"].to(model_8bit.device), max_new_tokens=10
)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@@ -362,7 +362,9 @@ class MixedInt8Test(BaseMixedInt8Test):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
self.model_fp16 = self.model_fp16.to(torch.float32)
_ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10)
_ = self.model_fp16.generate(
input_ids=encoded_input["input_ids"].to(self.model_fp16.device), max_new_tokens=10
)
# Check this does not throw an error
_ = self.model_fp16.to("cpu")
@@ -402,7 +404,7 @@ class MixedInt8Test(BaseMixedInt8Test):
# generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_from_saved.generate(
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
input_ids=encoded_input["input_ids"].to(model_from_saved.device), max_new_tokens=10
)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@@ -429,7 +431,7 @@ class MixedInt8Test(BaseMixedInt8Test):
# generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_from_saved.generate(
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
input_ids=encoded_input["input_ids"].to(model_from_saved.device), max_new_tokens=10
)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)