[training] SAVE_STATE_WARNING was removed in pytorch (#8979)

* [training] SAVE_STATE_WARNING was removed in pytorch

FYI `SAVE_STATE_WARNING` has been removed 3 days ago: pytorch/pytorch#46813

Fixes: #8232

@sgugger

* style, but add () to prevent autoformatters from botching it

* switch to try/except

* cleanup
This commit is contained in:
Stas Bekman
2020-12-07 21:59:55 -08:00
committed by GitHub
parent 2ae7388eee
commit 9d7d0005b0

View File

@@ -23,7 +23,6 @@ from typing import List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from packaging import version
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler from torch.utils.data.sampler import RandomSampler, Sampler
@@ -34,10 +33,11 @@ from .utils import logging
if is_torch_tpu_available(): if is_torch_tpu_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
if version.parse(torch.__version__) <= version.parse("1.4.1"): # this is used to supress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
SAVE_STATE_WARNING = "" try:
else:
from torch.optim.lr_scheduler import SAVE_STATE_WARNING from torch.optim.lr_scheduler import SAVE_STATE_WARNING
except ImportError:
SAVE_STATE_WARNING = ""
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)