Fix bad import with PyTorch <= 1.4.1 (#8237)
This commit is contained in:
@@ -23,7 +23,7 @@ from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
|
||||
from packaging import version
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, Sampler
|
||||
|
||||
@@ -34,6 +34,11 @@ from .utils import logging
|
||||
if is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
if version.parse(torch.__version__) <= version.parse("1.4.1"):
|
||||
SAVE_STATE_WARNING = ""
|
||||
else:
|
||||
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user