From 9de4afa8974b5afbaf61c41c4186eef6546932d4 Mon Sep 17 00:00:00 2001 From: Rakesh Chada Date: Mon, 18 May 2020 17:17:36 -0700 Subject: [PATCH] Make get_last_lr in trainer backward compatible (#4446) * makes fetching last learning late in trainer backward compatible * split comment to multiple lines * fixes black styling issue * uses version to create a more explicit logic --- src/transformers/trainer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 251f0dd4bc..f836987c28 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -10,6 +10,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch +from packaging import version from torch import nn from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset @@ -440,7 +441,14 @@ class Trainer: ): logs: Dict[str, float] = {} logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps - logs["learning_rate"] = scheduler.get_last_lr()[0] + # maintaining backward compatibility. + # could use "scheduler.get_last_lr()[0]" instead for pytorch >= 1.4.0 + logs["learning_rate"] = ( + scheduler.get_last_lr()[0] + if version.parse(torch.__version__) >= version.parse("1.4") + else scheduler.get_lr()[0] + ) + logging_loss = tr_loss self._log(logs)