Adds TRANSFORMERS_TEST_DEVICE (#25506)

* Adds `TRANSFORMERS_TEST_DEVICE`
Mirrors the same API in the diffusers library. Useful in transformers
too.

* replace backend checking with trying `torch.device`

* Adds better error message for unknown test devices

* `make style`

* adds documentation showing `TRANSFORMERS_TEST_DEVICE` usage.
This commit is contained in:
Alex McKinney
2023-08-17 12:41:34 +01:00
committed by GitHub
parent e7e9261a20
commit 1791ef8df6
2 changed files with 20 additions and 1 deletions

View File

@@ -614,7 +614,16 @@ if is_torch_available():
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
import torch
if torch.cuda.is_available():
if "TRANSFORMERS_TEST_DEVICE" in os.environ:
torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
try:
# try creating device to see if provided device is valid
_ = torch.device(torch_device)
except RuntimeError as e:
raise RuntimeError(
f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}"
) from e
elif torch.cuda.is_available():
torch_device = "cuda"
elif _run_third_party_device_tests and is_torch_npu_available():
torch_device = "npu"