[tests] ensure device-required software is available in the testing environment before testing (#29477)
* gix * fix style * add warning * revert * no newline * revert * revert * add CUDA as well
This commit is contained in:
@@ -808,6 +808,19 @@ if is_torch_available():
|
|||||||
|
|
||||||
if "TRANSFORMERS_TEST_DEVICE" in os.environ:
|
if "TRANSFORMERS_TEST_DEVICE" in os.environ:
|
||||||
torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
|
torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
|
||||||
|
if torch_device == "cuda" and not torch.cuda.is_available():
|
||||||
|
raise ValueError(
|
||||||
|
f"TRANSFORMERS_TEST_DEVICE={torch_device}, but CUDA is unavailable. Please double-check your testing environment."
|
||||||
|
)
|
||||||
|
if torch_device == "xpu" and not is_torch_xpu_available():
|
||||||
|
raise ValueError(
|
||||||
|
f"TRANSFORMERS_TEST_DEVICE={torch_device}, but XPU is unavailable. Please double-check your testing environment."
|
||||||
|
)
|
||||||
|
if torch_device == "npu" and not is_torch_npu_available():
|
||||||
|
raise ValueError(
|
||||||
|
f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# try creating device to see if provided device is valid
|
# try creating device to see if provided device is valid
|
||||||
_ = torch.device(torch_device)
|
_ = torch.device(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user