Make get_last_lr in trainer backward compatible (#4446)
* makes fetching last learning late in trainer backward compatible * split comment to multiple lines * fixes black styling issue * uses version to create a more explicit logic
This commit is contained in:
@@ -10,6 +10,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data.dataloader import DataLoader
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data.dataset import Dataset
|
||||||
@@ -440,7 +441,14 @@ class Trainer:
|
|||||||
):
|
):
|
||||||
logs: Dict[str, float] = {}
|
logs: Dict[str, float] = {}
|
||||||
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
|
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
|
||||||
logs["learning_rate"] = scheduler.get_last_lr()[0]
|
# maintaining backward compatibility.
|
||||||
|
# could use "scheduler.get_last_lr()[0]" instead for pytorch >= 1.4.0
|
||||||
|
logs["learning_rate"] = (
|
||||||
|
scheduler.get_last_lr()[0]
|
||||||
|
if version.parse(torch.__version__) >= version.parse("1.4")
|
||||||
|
else scheduler.get_lr()[0]
|
||||||
|
)
|
||||||
|
|
||||||
logging_loss = tr_loss
|
logging_loss = tr_loss
|
||||||
|
|
||||||
self._log(logs)
|
self._log(logs)
|
||||||
|
|||||||
Reference in New Issue
Block a user