diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py index 7b3e28a4c0..f8cf4af808 100644 --- a/run_classifier_pytorch.py +++ b/run_classifier_pytorch.py @@ -531,6 +531,7 @@ def main(): loss, _ = model(input_ids, segment_ids, input_mask, label_ids) total_tr_loss += loss.item() nb_tr_examples += input_ids.size(0) + model.zero_grad() loss.backward() optimizer.step() global_step += 1 diff --git a/run_squad_pytorch.py b/run_squad_pytorch.py index 0f1c4bce35..d5b771b91a 100644 --- a/run_squad_pytorch.py +++ b/run_squad_pytorch.py @@ -856,6 +856,7 @@ def main(): logger.info("HHHHH Forward") loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions) + model.zero_grad() logger.info("HHHHH Backward") loss.backward() logger.info("HHHHH Loading data")