From 9d7d0005b046a95d9d59354714bb6c3547a612fe Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 7 Dec 2020 21:59:55 -0800 Subject: [PATCH] [training] SAVE_STATE_WARNING was removed in pytorch (#8979) * [training] SAVE_STATE_WARNING was removed in pytorch FYI `SAVE_STATE_WARNING` has been removed 3 days ago: pytorch/pytorch#46813 Fixes: #8232 @sgugger * style, but add () to prevent autoformatters from botching it * switch to try/except * cleanup --- src/transformers/trainer_pt_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index cb3d4a5bfe..5cb45eb7bd 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -23,7 +23,6 @@ from typing import List, Optional, Union import numpy as np import torch -from packaging import version from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler, Sampler @@ -34,10 +33,11 @@ from .utils import logging if is_torch_tpu_available(): import torch_xla.core.xla_model as xm -if version.parse(torch.__version__) <= version.parse("1.4.1"): - SAVE_STATE_WARNING = "" -else: +# this is used to supress an undesired warning emitted by pytorch versions 1.4.2-1.7.0 +try: from torch.optim.lr_scheduler import SAVE_STATE_WARNING +except ImportError: + SAVE_STATE_WARNING = "" logger = logging.get_logger(__name__)