From ebf80e2e70ebe0054b28eb728c08591e5b488175 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 7 May 2020 10:34:04 -0400 Subject: [PATCH] Tpu trainer (#4146) * wip * wip * a last wip * Better logging when using TPUs * Correct argument name * Tests * fix * Metrics in evaluation * Update src/transformers/training_args.py * [tpu] Use launcher script instead * [tpu] lots of tweaks * Fix formatting Co-authored-by: Julien Chaumond --- examples/run_glue.py | 5 ++ examples/xla_spawn.py | 74 +++++++++++++++++++ src/transformers/trainer.py | 118 ++++++++++++++++++++++++------ src/transformers/training_args.py | 23 +++++- 4 files changed, 197 insertions(+), 23 deletions(-) create mode 100644 examples/xla_spawn.py diff --git a/examples/run_glue.py b/examples/run_glue.py index e58eb01211..fd568af107 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -202,5 +202,10 @@ def main(): return results +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + if __name__ == "__main__": main() diff --git a/examples/xla_spawn.py b/examples/xla_spawn.py new file mode 100644 index 0000000000..460e5d83a0 --- /dev/null +++ b/examples/xla_spawn.py @@ -0,0 +1,74 @@ +""" +A simple launcher script for TPU training + +Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py + +:: + >>> python xla_spawn.py --num_cores=NUM_CORES_YOU_HAVE + YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other + arguments of your training script) + +""" + + +import importlib +import os +import sys +from argparse import REMAINDER, ArgumentParser + +import torch_xla.distributed.xla_multiprocessing as xmp + + +def trim_suffix(s: str, suffix: str): + return s if not s.endswith(suffix) or len(suffix) == 0 else s[: -len(suffix)] + + +def parse_args(): + """ + Helper function parsing the command line options + @retval ArgumentParser + """ + parser = ArgumentParser( + description=( + "PyTorch TPU distributed training launch " + "helper utility that will spawn up " + "multiple distributed processes" + ) + ) + + # Optional arguments for the launch helper + parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).") + + # positional + parser.add_argument( + "training_script", + type=str, + help=( + "The full module name to the single TPU training " + "program/script to be launched in parallel, " + "followed by all the arguments for the " + "training script" + ), + ) + + # rest from the training program + parser.add_argument("training_script_args", nargs=REMAINDER) + + return parser.parse_args() + + +def main(): + args = parse_args() + + # Import training_script as a module. + mod_name = trim_suffix(os.path.basename(args.training_script), ".py") + mod = importlib.import_module(mod_name) + + # Patch sys.argv + sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)] + + xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 2fcfe4531d..af54330776 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -21,7 +21,7 @@ from .data.data_collator import DataCollator, DefaultDataCollator from .modeling_utils import PreTrainedModel from .optimization import AdamW, get_linear_schedule_with_warmup from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput -from .training_args import TrainingArguments +from .training_args import TrainingArguments, is_tpu_available try: @@ -36,6 +36,11 @@ def is_apex_available(): return _has_apex +if is_tpu_available(): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + import torch_xla.distributed.parallel_loader as pl + try: from torch.utils.tensorboard import SummaryWriter @@ -88,6 +93,12 @@ def torch_distributed_zero_first(local_rank: int): torch.distributed.barrier() +def get_tpu_sampler(dataset: Dataset): + if xm.xrt_world_size() <= 1: + return RandomSampler(dataset) + return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) + + class Trainer: """ Trainer is a simple but feature-complete training and eval loop for PyTorch, @@ -146,41 +157,73 @@ class Trainer: ) set_seed(self.args.seed) # Create output directory if needed - if self.args.local_rank in [-1, 0]: + if self.is_local_master(): os.makedirs(self.args.output_dir, exist_ok=True) + if is_tpu_available(): + # Set an xla_device flag on the model's config. + # We'll find a more elegant and not need to do this in the future. + self.model.config.xla_device = True def get_train_dataloader(self) -> DataLoader: if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") - train_sampler = ( - RandomSampler(self.train_dataset) if self.args.local_rank == -1 else DistributedSampler(self.train_dataset) - ) - return DataLoader( + if is_tpu_available(): + train_sampler = get_tpu_sampler(self.train_dataset) + else: + train_sampler = ( + RandomSampler(self.train_dataset) + if self.args.local_rank == -1 + else DistributedSampler(self.train_dataset) + ) + + data_loader = DataLoader( self.train_dataset, batch_size=self.args.train_batch_size, sampler=train_sampler, collate_fn=self.data_collator.collate_batch, ) + if is_tpu_available(): + data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device) + + return data_loader + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: if eval_dataset is None and self.eval_dataset is None: raise ValueError("Trainer: evaluation requires an eval_dataset.") - return DataLoader( + + sampler = get_tpu_sampler(eval_dataset) if is_tpu_available() else None + + data_loader = DataLoader( eval_dataset if eval_dataset is not None else self.eval_dataset, + sampler=sampler, batch_size=self.args.eval_batch_size, shuffle=False, collate_fn=self.data_collator.collate_batch, ) + if is_tpu_available(): + data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device) + + return data_loader + def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: # We use the same batch_size as for eval. - return DataLoader( + sampler = get_tpu_sampler(test_dataset) if is_tpu_available() else None + + data_loader = DataLoader( test_dataset, + sampler=sampler, batch_size=self.args.eval_batch_size, shuffle=False, collate_fn=self.data_collator.collate_batch, ) + if is_tpu_available(): + data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device) + + return data_loader + def get_optimizers( self, num_training_steps: int ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]: @@ -222,7 +265,6 @@ class Trainer: If present, we will try reloading the optimizer/scheduler states from there. """ train_dataloader = self.get_train_dataloader() - if self.args.max_steps > 0: t_total = self.args.max_steps num_train_epochs = ( @@ -271,16 +313,21 @@ class Trainer: self._setup_wandb() # Train! + if is_tpu_available(): + num_examples = len(train_dataloader._loader._loader.dataset) + total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size() + else: + num_examples = len(train_dataloader.dataset) + total_train_batch_size = ( + self.args.train_batch_size + * self.args.gradient_accumulation_steps + * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1), + ) logger.info("***** Running training *****") - logger.info(" Num examples = %d", len(train_dataloader.dataset)) + logger.info(" Num examples = %d", num_examples) logger.info(" Num Epochs = %d", num_train_epochs) - logger.info(" Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size) - logger.info( - " Total train batch size (w. parallel, distributed & accumulation) = %d", - self.args.train_batch_size - * self.args.gradient_accumulation_steps - * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1), - ) + logger.info(" Instantaneous batch size per device = %d", self.args.per_gpu_train_batch_size) + logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size) logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) @@ -309,10 +356,10 @@ class Trainer: logging_loss = 0.0 model.zero_grad() train_iterator = trange( - epochs_trained, int(num_train_epochs), desc="Epoch", disable=self.args.local_rank not in [-1, 0], + epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master() ) for epoch in train_iterator: - epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0]) + epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master()) for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training @@ -332,12 +379,16 @@ class Trainer: else: torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) - optimizer.step() + if is_tpu_available(): + xm.optimizer_step(optimizer) + else: + optimizer.step() + scheduler.step() model.zero_grad() global_step += 1 - if self.args.local_rank in [-1, 0]: + if self.is_local_master(): if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or ( global_step == 1 and self.args.logging_first_step ): @@ -371,6 +422,7 @@ class Trainer: assert model is self.model # Save model checkpoint output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}") + self.save_model(output_dir) self._rotate_checkpoints() torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) @@ -383,6 +435,9 @@ class Trainer: if self.args.max_steps > 0 and global_step > self.args.max_steps: train_iterator.close() break + if self.args.tpu_metrics_debug: + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) if self.tb_writer: self.tb_writer.close() @@ -413,12 +468,21 @@ class Trainer: return loss.item() + def is_local_master(self) -> bool: + if is_tpu_available(): + return xm.is_master_ordinal(local=True) + else: + return self.args.local_rank in [-1, 0] + def is_world_master(self) -> bool: """ This will be True only in one process, even in distributed mode, even when training on multiple machines. """ - return self.args.local_rank == -1 or torch.distributed.get_rank() == 0 + if is_tpu_available(): + return xm.is_master_ordinal(local=False) + else: + return self.args.local_rank == -1 or torch.distributed.get_rank() == 0 def save_model(self, output_dir: Optional[str] = None): """ @@ -495,6 +559,11 @@ class Trainer: eval_dataloader = self.get_eval_dataloader(eval_dataset) output = self._prediction_loop(eval_dataloader, description="Evaluation") + + if self.args.tpu_metrics_debug: + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + return output.metrics def predict(self, test_dataset: Dataset) -> PredictionOutput: @@ -558,6 +627,11 @@ class Trainer: else: label_ids = np.append(label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) + if is_tpu_available(): + # tpu-comment: Get all predictions and labels from all worker shards of eval dataset + preds = xm.mesh_reduce("eval_preds", preds, np.concatenate) + label_ids = xm.mesh_reduce("eval_out_label_ids", label_ids, np.concatenate) + if self.compute_metrics is not None and preds is not None and label_ids is not None: metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) else: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 5bffd44b0d..067a74d191 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -11,6 +11,19 @@ if is_torch_available(): import torch +try: + import torch_xla.core.xla_model as xm + + _has_tpu = True +except ImportError: + _has_tpu = False + + +@torch_required +def is_tpu_available(): + return _has_tpu + + logger = logging.getLogger(__name__) @@ -77,7 +90,7 @@ class TrainingArguments: ) }, ) - no_cuda: bool = field(default=False, metadata={"help": "Avoid using CUDA even if it is available"}) + no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"}) seed: int = field(default=42, metadata={"help": "random seed for initialization"}) fp16: bool = field( @@ -95,6 +108,11 @@ class TrainingArguments: ) local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"}) + tpu_num_cores: Optional[int] = field( + default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"} + ) + tpu_metrics_debug: bool = field(default=False, metadata={"help": "TPU: Whether to print debug metrics"}) + @property def train_batch_size(self) -> int: return self.per_gpu_train_batch_size * max(1, self.n_gpu) @@ -110,6 +128,9 @@ class TrainingArguments: if self.no_cuda: device = torch.device("cpu") n_gpu = 0 + elif is_tpu_available(): + device = xm.xla_device() + n_gpu = 0 elif self.local_rank == -1: # if n_gpu is > 1 we'll use nn.DataParallel. # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`