Tpu trainer (#4146)
* wip * wip * a last wip * Better logging when using TPUs * Correct argument name * Tests * fix * Metrics in evaluation * Update src/transformers/training_args.py * [tpu] Use launcher script instead * [tpu] lots of tweaks * Fix formatting Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -11,6 +11,19 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
try:
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
_has_tpu = True
|
||||
except ImportError:
|
||||
_has_tpu = False
|
||||
|
||||
|
||||
@torch_required
|
||||
def is_tpu_available():
|
||||
return _has_tpu
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -77,7 +90,7 @@ class TrainingArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
no_cuda: bool = field(default=False, metadata={"help": "Avoid using CUDA even if it is available"})
|
||||
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
|
||||
seed: int = field(default=42, metadata={"help": "random seed for initialization"})
|
||||
|
||||
fp16: bool = field(
|
||||
@@ -95,6 +108,11 @@ class TrainingArguments:
|
||||
)
|
||||
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
|
||||
|
||||
tpu_num_cores: Optional[int] = field(
|
||||
default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"}
|
||||
)
|
||||
tpu_metrics_debug: bool = field(default=False, metadata={"help": "TPU: Whether to print debug metrics"})
|
||||
|
||||
@property
|
||||
def train_batch_size(self) -> int:
|
||||
return self.per_gpu_train_batch_size * max(1, self.n_gpu)
|
||||
@@ -110,6 +128,9 @@ class TrainingArguments:
|
||||
if self.no_cuda:
|
||||
device = torch.device("cpu")
|
||||
n_gpu = 0
|
||||
elif is_tpu_available():
|
||||
device = xm.xla_device()
|
||||
n_gpu = 0
|
||||
elif self.local_rank == -1:
|
||||
# if n_gpu is > 1 we'll use nn.DataParallel.
|
||||
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
|
||||
|
||||
Reference in New Issue
Block a user