From 9d2ce253deec94ed55ae576847d5798f0e7defa1 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 21 May 2020 09:18:27 -0400 Subject: [PATCH] TPU hangs when saving optimizer/scheduler (#4467) * TPU hangs when saving optimizer/scheduler * Style * ParallelLoader is not a DataLoader * Style * Addressing @julien-c's comments --- src/transformers/trainer.py | 67 ++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 1105a6009f..e4b14750d0 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -242,9 +242,6 @@ class Trainer: collate_fn=self.data_collator.collate_batch, ) - if is_tpu_available(): - data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device) - return data_loader def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: @@ -269,9 +266,6 @@ class Trainer: collate_fn=self.data_collator.collate_batch, ) - if is_tpu_available(): - data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device) - return data_loader def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: @@ -292,9 +286,6 @@ class Trainer: collate_fn=self.data_collator.collate_batch, ) - if is_tpu_available(): - data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device) - return data_loader def get_optimizers( @@ -351,15 +342,11 @@ class Trainer: self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps) ) - def num_examples(self, dataloader: Union[DataLoader, "pl.PerDeviceLoader"]) -> int: + def num_examples(self, dataloader: DataLoader) -> int: """ Helper to get num of examples from a DataLoader, by accessing its Dataset. """ - if is_tpu_available(): - assert isinstance(dataloader, pl.PerDeviceLoader) - return len(dataloader._loader._loader.dataset) - else: - return len(dataloader.dataset) + return len(dataloader.dataset) def train(self, model_path: Optional[str] = None): """ @@ -466,7 +453,14 @@ class Trainer: if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): train_dataloader.sampler.set_epoch(epoch) - epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master()) + if is_tpu_available(): + parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader( + self.args.device + ) + epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master()) + else: + epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master()) + for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training @@ -514,24 +508,28 @@ class Trainer: if self.args.evaluate_during_training: self.evaluate() - if self.is_world_master(): - if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0: - # In all cases (even distributed/parallel), self.model is always a reference - # to the model we want to save. - if hasattr(model, "module"): - assert model.module is self.model - else: - assert model is self.model - # Save model checkpoint - output_dir = os.path.join( - self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}" - ) + if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0: + # In all cases (even distributed/parallel), self.model is always a reference + # to the model we want to save. + if hasattr(model, "module"): + assert model.module is self.model + else: + assert model is self.model + # Save model checkpoint + output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}") - self.save_model(output_dir) + self.save_model(output_dir) + + if self.is_world_master(): self._rotate_checkpoints() + + if is_tpu_available(): + xm.rendezvous("saving_optimizer_states") + xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) + xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) + elif self.is_world_master(): torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) - logger.info("Saving optimizer and scheduler states to %s", output_dir) if self.args.max_steps > 0 and self.global_step > self.args.max_steps: epoch_iterator.close() @@ -713,6 +711,7 @@ class Trainer: In that case, this method will also return metrics, like in evaluate(). """ test_dataloader = self.get_test_dataloader(test_dataset) + return self._prediction_loop(test_dataloader, description="Prediction") def _prediction_loop( @@ -735,10 +734,7 @@ class Trainer: # Note: in torch.distributed mode, there's no point in wrapping the model # inside a DistributedDataParallel as we'll be under `no_grad` anyways. - if is_tpu_available(): - batch_size = dataloader._loader._loader.batch_size - else: - batch_size = dataloader.batch_size + batch_size = dataloader.batch_size logger.info("***** Running %s *****", description) logger.info(" Num examples = %d", self.num_examples(dataloader)) logger.info(" Batch size = %d", batch_size) @@ -747,6 +743,9 @@ class Trainer: label_ids: torch.Tensor = None model.eval() + if is_tpu_available(): + dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device) + for inputs in tqdm(dataloader, desc=description): has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])