From 290633b882a706c862f3407cb9779063ce443e1c Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Sun, 4 Nov 2018 17:31:50 -0500 Subject: [PATCH] Fix `args.gradient_accumulation_steps` used before assigment. --- run_classifier.py | 6 +++++- run_squad.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/run_classifier.py b/run_classifier.py index 2c82bcd4c1..b5290afd12 100644 --- a/run_classifier.py +++ b/run_classifier.py @@ -404,6 +404,10 @@ def main(): type=int, default=42, help="random seed for initialization") + parser.add_argument('--gradient_accumulation_steps', + type=int, + default=1, + help="Number of updates steps to accumualte before performing a backward/update pass.") args = parser.parse_args() processors = { @@ -469,7 +473,7 @@ def main(): model = BertForSequenceClassification(bert_config, len(label_list)) if args.init_checkpoint is not None: - model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) + model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.to(device) if args.local_rank != -1: diff --git a/run_squad.py b/run_squad.py index 3a961079c9..8a69e057e5 100644 --- a/run_squad.py +++ b/run_squad.py @@ -739,7 +739,11 @@ def main(): type=int, default=42, help="random seed for initialization") - + parser.add_argument('--gradient_accumulation_steps', + type=int, + default=1, + help="Number of updates steps to accumualte before performing a backward/update pass.") + args = parser.parse_args() if args.local_rank == -1 or args.no_cuda: