[tests] make cuda-only tests device-agnostic (#35607)
* intial commit * remove unrelated files * further remove * Update test_trainer.py * fix style
This commit is contained in:
@@ -1862,7 +1862,6 @@ class ModelTesterMixin:
|
||||
def test_resize_tokens_embeddings(self):
|
||||
if not self.test_resize_embeddings:
|
||||
self.skipTest(reason="test_resize_embeddings is set to `False`")
|
||||
|
||||
(
|
||||
original_config,
|
||||
inputs_dict,
|
||||
@@ -2017,7 +2016,7 @@ class ModelTesterMixin:
|
||||
torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, atol=1e-3, rtol=1e-1)
|
||||
|
||||
@require_deepspeed
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_resize_tokens_embeddings_with_deepspeed(self):
|
||||
ds_config = {
|
||||
"zero_optimization": {
|
||||
@@ -2123,7 +2122,7 @@ class ModelTesterMixin:
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
@require_deepspeed
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_resize_embeddings_untied_with_deepspeed(self):
|
||||
ds_config = {
|
||||
"zero_optimization": {
|
||||
@@ -3202,7 +3201,7 @@ class ModelTesterMixin:
|
||||
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_disk_offload_bin(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -3243,7 +3242,7 @@ class ModelTesterMixin:
|
||||
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_disk_offload_safetensors(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -3278,7 +3277,7 @@ class ModelTesterMixin:
|
||||
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_cpu_offload(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -4746,7 +4745,7 @@ class ModelTesterMixin:
|
||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_torch_compile_for_training(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
Reference in New Issue
Block a user