From 26ce4dd8b79dce59e183a8aeefe20e7e98a49113 Mon Sep 17 00:00:00 2001 From: Alan Ji Date: Tue, 8 Aug 2023 19:48:50 +0800 Subject: [PATCH] Enable tests to run on third-party devcies (#25327) * enable unit tests to run on third-party devcies other than CUDA and CPU. * remove the modification that enabled ut on MPS * control test on third-party device by env variable * update --------- Co-authored-by: statelesshz --- src/transformers/testing_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index a9ab304d2a..b93a40daa2 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -175,6 +175,7 @@ _run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False) _tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None) _run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True) _run_tool_tests = parse_flag_from_env("RUN_TOOL_TESTS", default=False) +_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False) def is_pt_tf_cross_test(test_case): @@ -612,7 +613,12 @@ if is_torch_available(): # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode import torch - torch_device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + torch_device = "cuda" + elif _run_third_party_device_tests and is_torch_npu_available(): + torch_device = "npu" + else: + torch_device = "cpu" else: torch_device = None