Revert #4446 Since it introduces a new dependency
Some checks failed
GitHub-hosted runner / check_code_quality (push) Has been cancelled

This commit is contained in:
Lysandre
2020-05-22 10:49:45 -04:00
parent e0db6bbd65
commit 10d72390c0

View File

@@ -11,7 +11,6 @@ from typing import Callable, Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from packaging import version
from torch import nn from torch import nn
from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset from torch.utils.data.dataset import Dataset
@@ -495,11 +494,8 @@ class Trainer:
): ):
logs: Dict[str, float] = {} logs: Dict[str, float] = {}
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
# backward compatibility for pytorch schedulers
logs["learning_rate"] = ( logs["learning_rate"] = (
scheduler.get_last_lr()[0] scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else scheduler.get_lr()[0]
) )
logging_loss = tr_loss logging_loss = tr_loss