From d1ad4bff445d86fcf2700b9317bf6c029f86788a Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 2 Nov 2020 10:26:37 -0500 Subject: [PATCH] Fix bad import with PyTorch <= 1.4.1 (#8237) --- src/transformers/trainer_pt_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index f4cf668e42..f19edba609 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -23,7 +23,7 @@ from typing import List, Optional, Union import numpy as np import torch -from torch.optim.lr_scheduler import SAVE_STATE_WARNING +from packaging import version from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler, Sampler @@ -34,6 +34,11 @@ from .utils import logging if is_torch_tpu_available(): import torch_xla.core.xla_model as xm +if version.parse(torch.__version__) <= version.parse("1.4.1"): + SAVE_STATE_WARNING = "" +else: + from torch.optim.lr_scheduler import SAVE_STATE_WARNING + logger = logging.get_logger(__name__)