add ascend npu accelerator support (#24879)
* Add Ascend NPU accelerator support * fix style warining
This commit is contained in:
@@ -47,6 +47,7 @@ from .utils import (
|
||||
is_torch_bf16_cpu_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_neuroncore_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
logging,
|
||||
@@ -1368,12 +1369,13 @@ class TrainingArguments:
|
||||
self.framework == "pt"
|
||||
and is_torch_available()
|
||||
and (self.device.type != "cuda")
|
||||
and (self.device.type != "npu")
|
||||
and (get_xla_device_type(self.device) != "GPU")
|
||||
and (self.fp16 or self.fp16_full_eval)
|
||||
):
|
||||
raise ValueError(
|
||||
"FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
|
||||
" (`--fp16_full_eval`) can only be used on CUDA devices."
|
||||
" (`--fp16_full_eval`) can only be used on CUDA or NPU devices."
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -1769,6 +1771,10 @@ class TrainingArguments:
|
||||
elif self.use_cpu:
|
||||
device = torch.device("cpu")
|
||||
self._n_gpu = 0
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device("npu:0")
|
||||
torch.npu.set_device(device)
|
||||
self._n_gpu = 1
|
||||
else:
|
||||
# if n_gpu is > 1 we'll use nn.DataParallel.
|
||||
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
|
||||
|
||||
Reference in New Issue
Block a user