[tests] enable bnb tests on xpu (#36233)
* fix failed test * fix device * fix more device cases * add more cases * fix empty cache * Update test_4bit.py --------- Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -274,7 +274,7 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
"""
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = self.model_8bit.generate(
|
||||
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
|
||||
input_ids=encoded_input["input_ids"].to(self.model_8bit.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
@@ -292,7 +292,7 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_8bit_from_config.generate(
|
||||
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
|
||||
input_ids=encoded_input["input_ids"].to(model_8bit_from_config.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
@@ -311,7 +311,7 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_8bit.generate(
|
||||
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
|
||||
input_ids=encoded_input["input_ids"].to(model_8bit.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
@@ -362,7 +362,9 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
|
||||
self.model_fp16 = self.model_fp16.to(torch.float32)
|
||||
_ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10)
|
||||
_ = self.model_fp16.generate(
|
||||
input_ids=encoded_input["input_ids"].to(self.model_fp16.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
# Check this does not throw an error
|
||||
_ = self.model_fp16.to("cpu")
|
||||
@@ -402,7 +404,7 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
# generate
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_from_saved.generate(
|
||||
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
|
||||
input_ids=encoded_input["input_ids"].to(model_from_saved.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
@@ -429,7 +431,7 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
# generate
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_from_saved.generate(
|
||||
input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10
|
||||
input_ids=encoded_input["input_ids"].to(model_from_saved.device), max_new_tokens=10
|
||||
)
|
||||
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
Reference in New Issue
Block a user