From 2c1ebb8b507c19c75af5084b6c73e0b003c9eda6 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 22 May 2020 17:27:47 -0400 Subject: [PATCH] Re-apply #4446 + add packaging dependency As discussed w/ @lysandrejik packaging is maintained by PyPA (the Python Packaging Authority), and should be lightweight and stable --- setup.py | 2 ++ src/transformers/trainer.py | 8 +++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1b989e9277..19dad3b332 100644 --- a/setup.py +++ b/setup.py @@ -111,6 +111,8 @@ setup( "tokenizers == 0.7.0", # dataclasses for Python versions that don't have it "dataclasses;python_version<'3.7'", + # utilities from PyPA to e.g. compare versions + "packaging", # filesystem locks e.g. to prevent parallel downloads "filelock", # for downloading models over HTTPS diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 9a0f74d875..4362832644 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -11,6 +11,7 @@ from typing import Callable, Dict, List, Optional, Tuple import numpy as np import torch +from packaging import version from torch import nn from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset @@ -494,7 +495,12 @@ class Trainer: ): logs: Dict[str, float] = {} logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps - logs["learning_rate"] = scheduler.get_last_lr()[0] + # backward compatibility for pytorch schedulers + logs["learning_rate"] = ( + scheduler.get_last_lr()[0] + if version.parse(torch.__version__) >= version.parse("1.4") + else scheduler.get_lr()[0] + ) logging_loss = tr_loss self._log(logs)