[Trainer/Deepspeed] handle get_last_lr() before first step() (#10362)

* handle get_last_lr() before first step()

* abstract away the lr getting logic

* cleanup

* add test

* move to utils
This commit is contained in:
Stas Bekman
2021-02-23 17:42:25 -08:00
committed by GitHub
parent 4a1ab7cb6c
commit 3437d12134
3 changed files with 52 additions and 6 deletions

View File

@@ -82,6 +82,7 @@ from .trainer_pt_utils import (
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
get_learning_rate,
nested_concat,
nested_detach,
nested_numpify,
@@ -1129,12 +1130,8 @@ class Trainer:
tr_loss -= tr_loss
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
# backward compatibility for pytorch schedulers
logs["learning_rate"] = (
self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else self.lr_scheduler.get_lr()[0]
)
logs["learning_rate"] = get_learning_rate(self)
self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step

View File

@@ -24,6 +24,7 @@ from typing import Iterator, List, Optional, Union
import numpy as np
import torch
from packaging import version
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
@@ -262,6 +263,29 @@ def _get_first_shape(arrays):
return arrays.shape
def get_learning_rate(trainer):
if trainer.deepspeed:
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
# not run for the first few dozen steps while loss scale is too large, and thus during
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
try:
last_lr = trainer.lr_scheduler.get_last_lr()[0]
except AssertionError as e:
if "need to call step" in str(e):
logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
last_lr = 0
else:
raise
else:
last_lr = (
# backward compatibility for pytorch schedulers
trainer.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else trainer.lr_scheduler.get_lr()[0]
)
return last_lr
class DistributedTensorGatherer:
"""
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.