[tests] ensure device-required software is available in the testing environment before testing (#29477)
* gix * fix style * add warning * revert * no newline * revert * revert * add CUDA as well
This commit is contained in:
@@ -808,6 +808,19 @@ if is_torch_available():
|
||||
|
||||
if "TRANSFORMERS_TEST_DEVICE" in os.environ:
|
||||
torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
|
||||
if torch_device == "cuda" and not torch.cuda.is_available():
|
||||
raise ValueError(
|
||||
f"TRANSFORMERS_TEST_DEVICE={torch_device}, but CUDA is unavailable. Please double-check your testing environment."
|
||||
)
|
||||
if torch_device == "xpu" and not is_torch_xpu_available():
|
||||
raise ValueError(
|
||||
f"TRANSFORMERS_TEST_DEVICE={torch_device}, but XPU is unavailable. Please double-check your testing environment."
|
||||
)
|
||||
if torch_device == "npu" and not is_torch_npu_available():
|
||||
raise ValueError(
|
||||
f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment."
|
||||
)
|
||||
|
||||
try:
|
||||
# try creating device to see if provided device is valid
|
||||
_ = torch.device(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user