TPU hangs when saving optimizer/scheduler (#4467)
* TPU hangs when saving optimizer/scheduler * Style * ParallelLoader is not a DataLoader * Style * Addressing @julien-c's comments
This commit is contained in:
@@ -242,9 +242,6 @@ class Trainer:
|
|||||||
collate_fn=self.data_collator.collate_batch,
|
collate_fn=self.data_collator.collate_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_tpu_available():
|
|
||||||
data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
|
|
||||||
|
|
||||||
return data_loader
|
return data_loader
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||||
@@ -269,9 +266,6 @@ class Trainer:
|
|||||||
collate_fn=self.data_collator.collate_batch,
|
collate_fn=self.data_collator.collate_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_tpu_available():
|
|
||||||
data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
|
|
||||||
|
|
||||||
return data_loader
|
return data_loader
|
||||||
|
|
||||||
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
|
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
|
||||||
@@ -292,9 +286,6 @@ class Trainer:
|
|||||||
collate_fn=self.data_collator.collate_batch,
|
collate_fn=self.data_collator.collate_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_tpu_available():
|
|
||||||
data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
|
|
||||||
|
|
||||||
return data_loader
|
return data_loader
|
||||||
|
|
||||||
def get_optimizers(
|
def get_optimizers(
|
||||||
@@ -351,15 +342,11 @@ class Trainer:
|
|||||||
self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
|
self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
|
||||||
)
|
)
|
||||||
|
|
||||||
def num_examples(self, dataloader: Union[DataLoader, "pl.PerDeviceLoader"]) -> int:
|
def num_examples(self, dataloader: DataLoader) -> int:
|
||||||
"""
|
"""
|
||||||
Helper to get num of examples from a DataLoader, by accessing its Dataset.
|
Helper to get num of examples from a DataLoader, by accessing its Dataset.
|
||||||
"""
|
"""
|
||||||
if is_tpu_available():
|
return len(dataloader.dataset)
|
||||||
assert isinstance(dataloader, pl.PerDeviceLoader)
|
|
||||||
return len(dataloader._loader._loader.dataset)
|
|
||||||
else:
|
|
||||||
return len(dataloader.dataset)
|
|
||||||
|
|
||||||
def train(self, model_path: Optional[str] = None):
|
def train(self, model_path: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
@@ -466,7 +453,14 @@ class Trainer:
|
|||||||
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
||||||
train_dataloader.sampler.set_epoch(epoch)
|
train_dataloader.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
|
if is_tpu_available():
|
||||||
|
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
|
||||||
|
self.args.device
|
||||||
|
)
|
||||||
|
epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
|
||||||
|
else:
|
||||||
|
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
|
||||||
|
|
||||||
for step, inputs in enumerate(epoch_iterator):
|
for step, inputs in enumerate(epoch_iterator):
|
||||||
|
|
||||||
# Skip past any already trained steps if resuming training
|
# Skip past any already trained steps if resuming training
|
||||||
@@ -514,24 +508,28 @@ class Trainer:
|
|||||||
if self.args.evaluate_during_training:
|
if self.args.evaluate_during_training:
|
||||||
self.evaluate()
|
self.evaluate()
|
||||||
|
|
||||||
if self.is_world_master():
|
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
|
||||||
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
|
# In all cases (even distributed/parallel), self.model is always a reference
|
||||||
# In all cases (even distributed/parallel), self.model is always a reference
|
# to the model we want to save.
|
||||||
# to the model we want to save.
|
if hasattr(model, "module"):
|
||||||
if hasattr(model, "module"):
|
assert model.module is self.model
|
||||||
assert model.module is self.model
|
else:
|
||||||
else:
|
assert model is self.model
|
||||||
assert model is self.model
|
# Save model checkpoint
|
||||||
# Save model checkpoint
|
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
|
||||||
output_dir = os.path.join(
|
|
||||||
self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.save_model(output_dir)
|
self.save_model(output_dir)
|
||||||
|
|
||||||
|
if self.is_world_master():
|
||||||
self._rotate_checkpoints()
|
self._rotate_checkpoints()
|
||||||
|
|
||||||
|
if is_tpu_available():
|
||||||
|
xm.rendezvous("saving_optimizer_states")
|
||||||
|
xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
|
xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
|
elif self.is_world_master():
|
||||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
|
||||||
|
|
||||||
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
|
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
|
||||||
epoch_iterator.close()
|
epoch_iterator.close()
|
||||||
@@ -713,6 +711,7 @@ class Trainer:
|
|||||||
In that case, this method will also return metrics, like in evaluate().
|
In that case, this method will also return metrics, like in evaluate().
|
||||||
"""
|
"""
|
||||||
test_dataloader = self.get_test_dataloader(test_dataset)
|
test_dataloader = self.get_test_dataloader(test_dataset)
|
||||||
|
|
||||||
return self._prediction_loop(test_dataloader, description="Prediction")
|
return self._prediction_loop(test_dataloader, description="Prediction")
|
||||||
|
|
||||||
def _prediction_loop(
|
def _prediction_loop(
|
||||||
@@ -735,10 +734,7 @@ class Trainer:
|
|||||||
# Note: in torch.distributed mode, there's no point in wrapping the model
|
# Note: in torch.distributed mode, there's no point in wrapping the model
|
||||||
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
|
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
|
||||||
|
|
||||||
if is_tpu_available():
|
batch_size = dataloader.batch_size
|
||||||
batch_size = dataloader._loader._loader.batch_size
|
|
||||||
else:
|
|
||||||
batch_size = dataloader.batch_size
|
|
||||||
logger.info("***** Running %s *****", description)
|
logger.info("***** Running %s *****", description)
|
||||||
logger.info(" Num examples = %d", self.num_examples(dataloader))
|
logger.info(" Num examples = %d", self.num_examples(dataloader))
|
||||||
logger.info(" Batch size = %d", batch_size)
|
logger.info(" Batch size = %d", batch_size)
|
||||||
@@ -747,6 +743,9 @@ class Trainer:
|
|||||||
label_ids: torch.Tensor = None
|
label_ids: torch.Tensor = None
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
if is_tpu_available():
|
||||||
|
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
|
||||||
|
|
||||||
for inputs in tqdm(dataloader, desc=description):
|
for inputs in tqdm(dataloader, desc=description):
|
||||||
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
|
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user