From 272f48e734300cf6df66b5fcdaf462e47de3ccd6 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Fri, 15 Mar 2024 22:28:52 +0800 Subject: [PATCH] [tests] ensure device-required software is available in the testing environment before testing (#29477) * gix * fix style * add warning * revert * no newline * revert * revert * add CUDA as well --- src/transformers/testing_utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 038946b322..e357aaf9f1 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -808,6 +808,19 @@ if is_torch_available(): if "TRANSFORMERS_TEST_DEVICE" in os.environ: torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"] + if torch_device == "cuda" and not torch.cuda.is_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but CUDA is unavailable. Please double-check your testing environment." + ) + if torch_device == "xpu" and not is_torch_xpu_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but XPU is unavailable. Please double-check your testing environment." + ) + if torch_device == "npu" and not is_torch_npu_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment." + ) + try: # try creating device to see if provided device is valid _ = torch.device(torch_device)