From 3571187ef6f07a7ba63ee5b355e312f2fbfaaab7 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 15 Apr 2019 16:43:56 +0200 Subject: [PATCH] fix saving models in distributed setting examples --- examples/run_classifier.py | 1 + examples/run_squad.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 112be6fbcb..4994118467 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -859,6 +859,7 @@ def main(): optimizer.zero_grad() global_step += 1 + if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): # Save a trained model, configuration and tokenizer model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self diff --git a/examples/run_squad.py b/examples/run_squad.py index cd85219f5f..410fd85298 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -1020,7 +1020,7 @@ def main(): optimizer.zero_grad() global_step += 1 - if args.do_train: + if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): # Save a trained model, configuration and tokenizer model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self