TPU hangs when saving optimizer/scheduler (#4467)
* TPU hangs when saving optimizer/scheduler
* Style
* ParallelLoader is not a DataLoader
* Style
* Addressing @julien-c's comments
This commit is contained in:
@@ -242,9 +242,6 @@ class Trainer:
|
||||
collate_fn=self.data_collator.collate_batch,
|
||||
)
|
||||
|
||||
if is_tpu_available():
|
||||
data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
|
||||
|
||||
return data_loader
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
@@ -269,9 +266,6 @@ class Trainer:
|
||||
collate_fn=self.data_collator.collate_batch,
|
||||
)
|
||||
|
||||
if is_tpu_available():
|
||||
data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
|
||||
|
||||
return data_loader
|
||||
|
||||
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
|
||||
@@ -292,9 +286,6 @@ class Trainer:
|
||||
collate_fn=self.data_collator.collate_batch,
|
||||
)
|
||||
|
||||
if is_tpu_available():
|
||||
data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
|
||||
|
||||
return data_loader
|
||||
|
||||
def get_optimizers(
|
||||
@@ -351,15 +342,11 @@ class Trainer:
|
||||
self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
|
||||
)
|
||||
|
||||
def num_examples(self, dataloader: Union[DataLoader, "pl.PerDeviceLoader"]) -> int:
|
||||
def num_examples(self, dataloader: DataLoader) -> int:
|
||||
"""
|
||||
Helper to get num of examples from a DataLoader, by accessing its Dataset.
|
||||
"""
|
||||
if is_tpu_available():
|
||||
assert isinstance(dataloader, pl.PerDeviceLoader)
|
||||
return len(dataloader._loader._loader.dataset)
|
||||
else:
|
||||
return len(dataloader.dataset)
|
||||
return len(dataloader.dataset)
|
||||
|
||||
def train(self, model_path: Optional[str] = None):
|
||||
"""
|
||||
@@ -466,7 +453,14 @@ class Trainer:
|
||||
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
||||
train_dataloader.sampler.set_epoch(epoch)
|
||||
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
|
||||
if is_tpu_available():
|
||||
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
|
||||
self.args.device
|
||||
)
|
||||
epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
|
||||
else:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
|
||||
|
||||
for step, inputs in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
@@ -514,24 +508,28 @@ class Trainer:
|
||||
if self.args.evaluate_during_training:
|
||||
self.evaluate()
|
||||
|
||||
if self.is_world_master():
|
||||
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
|
||||
# In all cases (even distributed/parallel), self.model is always a reference
|
||||
# to the model we want to save.
|
||||
if hasattr(model, "module"):
|
||||
assert model.module is self.model
|
||||
else:
|
||||
assert model is self.model
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(
|
||||
self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
|
||||
)
|
||||
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
|
||||
# In all cases (even distributed/parallel), self.model is always a reference
|
||||
# to the model we want to save.
|
||||
if hasattr(model, "module"):
|
||||
assert model.module is self.model
|
||||
else:
|
||||
assert model is self.model
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
|
||||
|
||||
self.save_model(output_dir)
|
||||
self.save_model(output_dir)
|
||||
|
||||
if self.is_world_master():
|
||||
self._rotate_checkpoints()
|
||||
|
||||
if is_tpu_available():
|
||||
xm.rendezvous("saving_optimizer_states")
|
||||
xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
elif self.is_world_master():
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||
|
||||
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
|
||||
epoch_iterator.close()
|
||||
@@ -713,6 +711,7 @@ class Trainer:
|
||||
In that case, this method will also return metrics, like in evaluate().
|
||||
"""
|
||||
test_dataloader = self.get_test_dataloader(test_dataset)
|
||||
|
||||
return self._prediction_loop(test_dataloader, description="Prediction")
|
||||
|
||||
def _prediction_loop(
|
||||
@@ -735,10 +734,7 @@ class Trainer:
|
||||
# Note: in torch.distributed mode, there's no point in wrapping the model
|
||||
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
|
||||
|
||||
if is_tpu_available():
|
||||
batch_size = dataloader._loader._loader.batch_size
|
||||
else:
|
||||
batch_size = dataloader.batch_size
|
||||
batch_size = dataloader.batch_size
|
||||
logger.info("***** Running %s *****", description)
|
||||
logger.info(" Num examples = %d", self.num_examples(dataloader))
|
||||
logger.info(" Batch size = %d", batch_size)
|
||||
@@ -747,6 +743,9 @@ class Trainer:
|
||||
label_ids: torch.Tensor = None
|
||||
model.eval()
|
||||
|
||||
if is_tpu_available():
|
||||
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
|
||||
|
||||
for inputs in tqdm(dataloader, desc=description):
|
||||
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user