From 5f04aa00edd7fe08d4beae28d831b5c556b0c406 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 9 Nov 2018 11:28:14 +0100 Subject: [PATCH] option to perform optimization and keep the optimizer averages on CPU --- run_squad.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/run_squad.py b/run_squad.py index 59bd32c7c6..e44044f9a0 100644 --- a/run_squad.py +++ b/run_squad.py @@ -719,7 +719,6 @@ def main(): parser.add_argument("--max_answer_length", default=30, type=int, help="The maximum length of an answer that can be generated. This is needed because the start " "and end predictions are not conditioned on one another.") - parser.add_argument("--verbose_logging", default=False, action='store_true', help="If true, all of the warnings related to data processing will be printed. " "A number of warnings are expected for a normal SQuAD evaluation.") @@ -727,10 +726,6 @@ def main(): default=False, action='store_true', help="Whether not to use CUDA when available") - parser.add_argument("--local_rank", - type=int, - default=-1, - help="local_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=42, @@ -738,7 +733,16 @@ def main(): parser.add_argument('--gradient_accumulation_steps', type=int, default=1, - help="Number of updates steps to accumualte before performing a backward/update pass.") + help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument("--local_rank", + type=int, + default=-1, + help="local_rank for distributed training on gpus") + parser.add_argument('--optimize_on_cpu', + default=False, + action='store_true', + help="Whether to perform optimization and keep the optimizer averages on CPU") + args = parser.parse_args() @@ -802,25 +806,26 @@ def main(): model = BertForQuestionAnswering(bert_config) if args.init_checkpoint is not None: model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) - model.to(device) - - if args.local_rank != -1: - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], - output_device=args.local_rank) - elif n_gpu > 1: - model = torch.nn.DataParallel(model) + if not args.optimize_on_cpu: + model.to(device) no_decay = ['bias', 'gamma', 'beta'] optimizer_parameters = [ {'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01}, {'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0} ] - optimizer = BERTAdam(optimizer_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=num_train_steps) + model.to(device) + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], + output_device=args.local_rank) + elif n_gpu > 1: + model = torch.nn.DataParallel(model) + global_step = 0 if args.do_train: train_features = convert_examples_to_features( @@ -862,8 +867,12 @@ def main(): loss = loss / args.gradient_accumulation_steps loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: + if args.optimize_on_cpu: + model.to('cpu') optimizer.step() # We have accumulated enought gradients model.zero_grad() + if args.optimize_on_cpu: + model.to(device) global_step += 1 if args.do_predict: