Re-apply #4446 + add packaging dependency

As discussed w/ @lysandrejik

packaging is maintained by PyPA (the Python Packaging Authority), and should be lightweight and stable
This commit is contained in:
Julien Chaumond
2020-05-22 17:27:47 -04:00
parent e6aeb0d3e8
commit 2c1ebb8b50
2 changed files with 9 additions and 1 deletions

View File

@@ -111,6 +111,8 @@ setup(
"tokenizers == 0.7.0", "tokenizers == 0.7.0",
# dataclasses for Python versions that don't have it # dataclasses for Python versions that don't have it
"dataclasses;python_version<'3.7'", "dataclasses;python_version<'3.7'",
# utilities from PyPA to e.g. compare versions
"packaging",
# filesystem locks e.g. to prevent parallel downloads # filesystem locks e.g. to prevent parallel downloads
"filelock", "filelock",
# for downloading models over HTTPS # for downloading models over HTTPS

View File

@@ -11,6 +11,7 @@ from typing import Callable, Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from packaging import version
from torch import nn from torch import nn
from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset from torch.utils.data.dataset import Dataset
@@ -494,7 +495,12 @@ class Trainer:
): ):
logs: Dict[str, float] = {} logs: Dict[str, float] = {}
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
logs["learning_rate"] = scheduler.get_last_lr()[0] # backward compatibility for pytorch schedulers
logs["learning_rate"] = (
scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else scheduler.get_lr()[0]
)
logging_loss = tr_loss logging_loss = tr_loss
self._log(logs) self._log(logs)