Make torch xla available on GPU (#29334)
* add USE_TORCH_XLA env * rename torch_tpu to torch_xla * better is_torch_xla_available; fix some fsdp and performance issues * fix format * fix bug when pjrt_device is cpu * fix bug * fix the deprecation handling --------- Co-authored-by: anw90 <ang868@gmail.com> Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
This commit is contained in:
@@ -24,13 +24,13 @@ import quant_trainer
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import Trainer, is_torch_tpu_available
|
||||
from transformers import Trainer, is_torch_xla_available
|
||||
from transformers.trainer_utils import PredictionOutput
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if is_torch_tpu_available(check_device=False):
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.metrics as met
|
||||
|
||||
|
||||
Reference in New Issue
Block a user