From cb76c1ddd39bc6c396ee031ca6c21f63e8b8cd7b Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sat, 3 Nov 2018 17:40:12 +0100 Subject: [PATCH] add model.zero_grad() --- run_classifier_pytorch.py | 1 + run_squad_pytorch.py | 1 + 2 files changed, 2 insertions(+) diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py index 7b3e28a4c0..f8cf4af808 100644 --- a/run_classifier_pytorch.py +++ b/run_classifier_pytorch.py @@ -531,6 +531,7 @@ def main(): loss, _ = model(input_ids, segment_ids, input_mask, label_ids) total_tr_loss += loss.item() nb_tr_examples += input_ids.size(0) + model.zero_grad() loss.backward() optimizer.step() global_step += 1 diff --git a/run_squad_pytorch.py b/run_squad_pytorch.py index 0f1c4bce35..d5b771b91a 100644 --- a/run_squad_pytorch.py +++ b/run_squad_pytorch.py @@ -856,6 +856,7 @@ def main(): logger.info("HHHHH Forward") loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions) + model.zero_grad() logger.info("HHHHH Backward") loss.backward() logger.info("HHHHH Loading data")