tests: revert change of torch_require_multi_gpu to be device agnostic (#35721)

* tests: revert change of torch_require_multi_gpu to be device agnostic

The 11c27dd33 modified `torch_require_multi_gpu()` to be device agnostic
instead of being CUDA specific. This broke some tests which are rightfully
CUDA specific, such as:

* `tests/trainer/test_trainer_distributed.py::TestTrainerDistributed`

In the current Transformers tests architecture `require_torch_multi_accelerator()`
should be used to mark multi-GPU tests agnostic to device.

This change addresses the issue introduced by 11c27dd33 and reverts
modification of `torch_require_multi_gpu()`.

Fixes: 11c27dd33 ("Enable BNB multi-backend support (#31098)")
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>

* fix bug: modification of frozen set

---------

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Dmitry Rogozhkin
2025-02-25 04:36:10 -08:00
committed by GitHub
parent d80d52b007
commit b4b9da6d9b
4 changed files with 10 additions and 21 deletions

View File

@@ -39,7 +39,7 @@ from transformers.testing_utils import (
require_bitsandbytes,
require_torch,
require_torch_gpu_if_bnb_not_multi_backend_enabled,
require_torch_multi_gpu,
require_torch_multi_accelerator,
slow,
torch_device,
)
@@ -671,7 +671,7 @@ class MixedInt8TestPipeline(BaseMixedInt8Test):
self.assertIn(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUTS)
@require_torch_multi_gpu
@require_torch_multi_accelerator
@apply_skip_if_not_implemented
class MixedInt8TestMultiGpu(BaseMixedInt8Test):
def setUp(self):
@@ -700,7 +700,7 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@require_torch_multi_gpu
@require_torch_multi_accelerator
@apply_skip_if_not_implemented
class MixedInt8TestCpuGpu(BaseMixedInt8Test):
def setUp(self):