updating gather function with gather_for_metrics in run_wav2vec2_pretraining (#18877)

Co-authored-by: Arun Rajaram <arunrajaram@Aruns-MacBook-Pro.local>
This commit is contained in:
arun99481
2022-09-06 17:06:37 +05:30
committed by GitHub
parent 734b7e2a5a
commit 3b19c0317b

View File

@@ -596,7 +596,7 @@ def main():
# make sure that `num_losses` is summed for distributed training # make sure that `num_losses` is summed for distributed training
# and average gradients over losses of all devices # and average gradients over losses of all devices
if accelerator.state.num_processes > 1: if accelerator.state.num_processes > 1:
num_losses = accelerator.gather(num_losses).sum() num_losses = accelerator.gather_for_metrics(num_losses).sum()
gradient_multiplier = accelerator.state.num_processes / num_losses gradient_multiplier = accelerator.state.num_processes / num_losses
multiply_grads(model.module.parameters(), gradient_multiplier) multiply_grads(model.module.parameters(), gradient_multiplier)
else: else:
@@ -647,10 +647,10 @@ def main():
outputs.diversity_loss.detach() outputs.diversity_loss.detach()
if accelerator.state.num_processes > 1: if accelerator.state.num_processes > 1:
loss = accelerator.gather(loss).sum() loss = accelerator.gather_for_metrics(loss).sum()
outputs.contrastive_loss = accelerator.gather(outputs.contrastive_loss).sum() outputs.contrastive_loss = accelerator.gather_for_metrics(outputs.contrastive_loss).sum()
outputs.diversity_loss = accelerator.gather(outputs.diversity_loss).sum() outputs.diversity_loss = accelerator.gather_for_metrics(outputs.diversity_loss).sum()
percent_masked = accelerator.gather(percent_masked).sum() percent_masked = accelerator.gather_for_metrics(percent_masked).sum()
train_logs = { train_logs = {
"loss": (loss * args.gradient_accumulation_steps) / num_losses, "loss": (loss * args.gradient_accumulation_steps) / num_losses,
@@ -713,7 +713,7 @@ def main():
# sum over devices in multi-processing # sum over devices in multi-processing
if accelerator.num_processes > 1: if accelerator.num_processes > 1:
val_logs = {k: accelerator.gather(v).sum() for k, v in val_logs.items()} val_logs = {k: accelerator.gather_for_metrics(v).sum() for k, v in val_logs.items()}
val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()} val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()}