Switch to device-agnostic API calls in test cases (#38247)
* use device-agnostic APIs in test cases
  Signed-off-by: Matrix Yao <matrix.yao@intel.com>
* fix style
  Signed-off-by: Matrix Yao <matrix.yao@intel.com>
* add one more
  Signed-off-by: YAO Matrix <matrix.yao@intel.com>
* XPU now supports integer device IDs, aligning with CUDA behavior
  Signed-off-by: Matrix Yao <matrix.yao@intel.com>
* update to use device_properties
  Signed-off-by: Matrix Yao <matrix.yao@intel.com>
* fix style
  Signed-off-by: Matrix Yao <matrix.yao@intel.com>
* update comment
  Signed-off-by: Matrix Yao <matrix.yao@intel.com>
* fix comments
  Signed-off-by: Matrix Yao <matrix.yao@intel.com>
* fix style
  Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -31,6 +31,8 @@ from transformers import (
|
||||
from transformers.models.opt.modeling_opt import OPTAttention
|
||||
from transformers.testing_utils import (
|
||||
apply_skip_if_not_implemented,
|
||||
backend_empty_cache,
|
||||
backend_torch_accelerator_module,
|
||||
is_accelerate_available,
|
||||
is_bitsandbytes_available,
|
||||
is_torch_available,
|
||||
@@ -137,7 +139,7 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
del self.model_8bit
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_get_keys_to_not_convert(self):
|
||||
r"""
|
||||
@@ -484,7 +486,7 @@ class MixedInt8T5Test(unittest.TestCase):
|
||||
avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
|
||||
"""
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_inference_without_keep_in_fp32(self):
|
||||
r"""
|
||||
@@ -599,7 +601,7 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test):
|
||||
del self.seq_to_seq_model
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_correct_head_class(self):
|
||||
r"""
|
||||
@@ -631,7 +633,7 @@ class MixedInt8TestPipeline(BaseMixedInt8Test):
|
||||
del self.pipe
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_pipeline(self):
|
||||
r"""
|
||||
@@ -872,10 +874,10 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True)
|
||||
model.train()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})
|
||||
elif torch.xpu.is_available():
|
||||
self.assertEqual(set(model.hf_device_map.values()), {f"xpu:{torch.xpu.current_device()}"})
|
||||
if torch_device in ["cuda", "xpu"]:
|
||||
self.assertEqual(
|
||||
set(model.hf_device_map.values()), {backend_torch_accelerator_module(torch_device).current_device()}
|
||||
)
|
||||
else:
|
||||
self.assertTrue(all(param.device.type == "cpu" for param in model.parameters()))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user