From 1791ef8df647a38b4fcb96c14ddd83a43861d713 Mon Sep 17 00:00:00 2001 From: Alex McKinney <44398246+vvvm23@users.noreply.github.com> Date: Thu, 17 Aug 2023 12:41:34 +0100 Subject: [PATCH] Adds `TRANSFORMERS_TEST_DEVICE` (#25506) * Adds `TRANSFORMERS_TEST_DEVICE` Mirrors the same API in the diffusers library. Useful in transformers too. * replace backend checking with trying `torch.device` * Adds better error message for unknown test devices * `make style` * adds documentation showing `TRANSFORMERS_TEST_DEVICE` usage. --- docs/source/en/testing.md | 10 ++++++++++ src/transformers/testing_utils.py | 11 ++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/docs/source/en/testing.md b/docs/source/en/testing.md index d3c512e8eb..0179d5c635 100644 --- a/docs/source/en/testing.md +++ b/docs/source/en/testing.md @@ -511,6 +511,16 @@ from transformers.testing_utils import get_gpu_count n_gpu = get_gpu_count() # works with torch and tf ``` +### Testing with a specific PyTorch backend + +To run the test suite on a specific torch backend add `TRANSFORMERS_TEST_DEVICE="$device"` where `$device` is the target backend. For example, to test on CPU only: +```bash +TRANSFORMERS_TEST_DEVICE="cpu" pytest tests/test_logging.py +``` + +This variable is useful for testing custom or less common PyTorch backends such as `mps`. It can also be used to achieve the same effect as `CUDA_VISIBLE_DEVICES` by targeting specific GPUs or testing in CPU-only mode. + + ### Distributed training `pytest` can't deal with distributed training directly. If this is attempted - the sub-processes don't do the right diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 94811a7bc8..18d5880a17 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -614,7 +614,16 @@ if is_torch_available(): # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode import torch - if torch.cuda.is_available(): + if "TRANSFORMERS_TEST_DEVICE" in os.environ: + torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"] + try: + # try creating device to see if provided device is valid + _ = torch.device(torch_device) + except RuntimeError as e: + raise RuntimeError( + f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}" + ) from e + elif torch.cuda.is_available(): torch_device = "cuda" elif _run_third_party_device_tests and is_torch_npu_available(): torch_device = "npu"