Fix all is_torch_tpu_available issues (#17936)
* Fix all is_torch_tpu_available
This commit is contained in:
@@ -20,7 +20,7 @@ from transformers import Trainer, is_torch_tpu_available
|
||||
from transformers.trainer_utils import PredictionOutput
|
||||
|
||||
|
||||
if is_torch_tpu_available():
|
||||
if is_torch_tpu_available(check_device=False):
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.metrics as met
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from transformers import Seq2SeqTrainer, is_torch_tpu_available
|
||||
from transformers.trainer_utils import PredictionOutput
|
||||
|
||||
|
||||
if is_torch_tpu_available():
|
||||
if is_torch_tpu_available(check_device=False):
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.metrics as met
|
||||
|
||||
|
||||
Reference in New Issue
Block a user