[tests] make cuda-only tests device-agnostic (#35607)

* intial commit

* remove unrelated files

* further remove

* Update test_trainer.py

* fix style
This commit is contained in:
Fanli Lin
2025-01-13 21:48:39 +08:00
committed by GitHub
parent e6f9b03464
commit 2fa876d2d8
18 changed files with 57 additions and 47 deletions

View File

@@ -1862,7 +1862,6 @@ class ModelTesterMixin:
def test_resize_tokens_embeddings(self):
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")
(
original_config,
inputs_dict,
@@ -2017,7 +2016,7 @@ class ModelTesterMixin:
torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, atol=1e-3, rtol=1e-1)
@require_deepspeed
@require_torch_gpu
@require_torch_accelerator
def test_resize_tokens_embeddings_with_deepspeed(self):
ds_config = {
"zero_optimization": {
@@ -2123,7 +2122,7 @@ class ModelTesterMixin:
model(**self._prepare_for_class(inputs_dict, model_class))
@require_deepspeed
@require_torch_gpu
@require_torch_accelerator
def test_resize_embeddings_untied_with_deepspeed(self):
ds_config = {
"zero_optimization": {
@@ -3202,7 +3201,7 @@ class ModelTesterMixin:
@require_accelerate
@mark.accelerate_tests
@require_torch_gpu
@require_torch_accelerator
def test_disk_offload_bin(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -3243,7 +3242,7 @@ class ModelTesterMixin:
@require_accelerate
@mark.accelerate_tests
@require_torch_gpu
@require_torch_accelerator
def test_disk_offload_safetensors(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -3278,7 +3277,7 @@ class ModelTesterMixin:
@require_accelerate
@mark.accelerate_tests
@require_torch_gpu
@require_torch_accelerator
def test_cpu_offload(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -4746,7 +4745,7 @@ class ModelTesterMixin:
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
@slow
@require_torch_gpu
@require_torch_accelerator
def test_torch_compile_for_training(self):
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")