Revert #4446 Since it introduces a new dependency
Some checks failed
GitHub-hosted runner / check_code_quality (push) Has been cancelled

This commit is contained in:
Lysandre
2020-05-22 10:49:45 -04:00
parent e0db6bbd65
commit 10d72390c0

View File

@@ -11,7 +11,6 @@ from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
@@ -495,11 +494,8 @@ class Trainer:
):
logs: Dict[str, float] = {}
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
# backward compatibility for pytorch schedulers
logs["learning_rate"] = (
scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else scheduler.get_lr()[0]
)
logging_loss = tr_loss