Make get_last_lr in trainer backward compatible (#4446)

* makes fetching last learning late in trainer backward compatible

* split comment to multiple lines

* fixes black styling issue

* uses version to create a more explicit logic
This commit is contained in:
Rakesh Chada
2020-05-18 17:17:36 -07:00
committed by GitHub
parent 42e8fbfc51
commit 9de4afa897

View File

@@ -10,6 +10,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from packaging import version
from torch import nn from torch import nn
from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset from torch.utils.data.dataset import Dataset
@@ -440,7 +441,14 @@ class Trainer:
): ):
logs: Dict[str, float] = {} logs: Dict[str, float] = {}
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
logs["learning_rate"] = scheduler.get_last_lr()[0] # maintaining backward compatibility.
# could use "scheduler.get_last_lr()[0]" instead for pytorch >= 1.4.0
logs["learning_rate"] = (
scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else scheduler.get_lr()[0]
)
logging_loss = tr_loss logging_loss = tr_loss
self._log(logs) self._log(logs)