switch to device agnostic device calling for test cases (#38247)
* use device agnostic APIs in test cases Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix style Signed-off-by: Matrix Yao <matrix.yao@intel.com> * add one more Signed-off-by: YAO Matrix <matrix.yao@intel.com> * xpu now supports integer device id, aligning to CUDA behaviors Signed-off-by: Matrix Yao <matrix.yao@intel.com> * update to use device_properties Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix style Signed-off-by: Matrix Yao <matrix.yao@intel.com> * update comment Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix comments Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix style Signed-off-by: Matrix Yao <matrix.yao@intel.com> --------- Signed-off-by: Matrix Yao <matrix.yao@intel.com> Signed-off-by: YAO Matrix <matrix.yao@intel.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -23,6 +23,7 @@ from transformers import (
|
||||
OPTForCausalLM,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
backend_empty_cache,
|
||||
require_accelerate,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
@@ -56,7 +57,6 @@ class BitNetQuantConfigTest(unittest.TestCase):
|
||||
@require_accelerate
|
||||
class BitNetTest(unittest.TestCase):
|
||||
model_name = "HF1BitLLM/Llama3-8B-1.58-100B-tokens"
|
||||
device = "cuda"
|
||||
|
||||
# called only once for all test in this class
|
||||
@classmethod
|
||||
@@ -65,11 +65,11 @@ class BitNetTest(unittest.TestCase):
|
||||
Load the model
|
||||
"""
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
|
||||
cls.quantized_model = AutoModelForCausalLM.from_pretrained(cls.model_name, device_map=cls.device)
|
||||
cls.quantized_model = AutoModelForCausalLM.from_pretrained(cls.model_name, device_map=torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
def test_replace_with_bitlinear(self):
|
||||
@@ -100,7 +100,7 @@ class BitNetTest(unittest.TestCase):
|
||||
"""
|
||||
input_text = "What are we having for dinner?"
|
||||
expected_output = "What are we having for dinner? What are we going to do for fun this weekend?"
|
||||
input_ids = self.tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
input_ids = self.tokenizer(input_text, return_tensors="pt").to(torch_device)
|
||||
|
||||
output = self.quantized_model.generate(**input_ids, max_new_tokens=11, do_sample=False)
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
|
||||
@@ -127,7 +127,7 @@ class BitNetTest(unittest.TestCase):
|
||||
from transformers.integrations import BitLinear
|
||||
|
||||
layer = BitLinear(in_features=4, out_features=2, bias=False, dtype=torch.float32)
|
||||
layer.to(self.device)
|
||||
layer.to(torch_device)
|
||||
|
||||
input_tensor = torch.tensor([1.0, -1.0, -1.0, 1.0], dtype=torch.float32).to(torch_device)
|
||||
|
||||
@@ -202,9 +202,8 @@ class BitNetTest(unittest.TestCase):
|
||||
class BitNetSerializationTest(unittest.TestCase):
|
||||
def test_model_serialization(self):
|
||||
model_name = "HF1BitLLM/Llama3-8B-1.58-100B-tokens"
|
||||
device = "cuda"
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
|
||||
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=device)
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=torch_device)
|
||||
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits_ref = quantized_model.forward(input_tensor).logits
|
||||
@@ -215,10 +214,10 @@ class BitNetSerializationTest(unittest.TestCase):
|
||||
|
||||
# Remove old model
|
||||
del quantized_model
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
# Load and check if the logits match
|
||||
model_loaded = AutoModelForCausalLM.from_pretrained("quant_model", device_map=device)
|
||||
model_loaded = AutoModelForCausalLM.from_pretrained("quant_model", device_map=torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits_loaded = model_loaded.forward(input_tensor).logits
|
||||
|
||||
Reference in New Issue
Block a user