Adds TRANSFORMERS_TEST_DEVICE (#25506)
* Adds `TRANSFORMERS_TEST_DEVICE` Mirrors the same API in the diffusers library. Useful in transformers too. * replace backend checking with trying `torch.device` * Adds better error message for unknown test devices * `make style` * adds documentation showing `TRANSFORMERS_TEST_DEVICE` usage.
This commit is contained in:
@@ -614,7 +614,16 @@ if is_torch_available():
|
||||
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if "TRANSFORMERS_TEST_DEVICE" in os.environ:
|
||||
torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
|
||||
try:
|
||||
# try creating device to see if provided device is valid
|
||||
_ = torch.device(torch_device)
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(
|
||||
f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}"
|
||||
) from e
|
||||
elif torch.cuda.is_available():
|
||||
torch_device = "cuda"
|
||||
elif _run_third_party_device_tests and is_torch_npu_available():
|
||||
torch_device = "npu"
|
||||
|
||||
Reference in New Issue
Block a user