Fix bad import with PyTorch <= 1.4.1 (#8237)
This commit is contained in:
@@ -23,7 +23,7 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
|
from packaging import version
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from torch.utils.data.sampler import RandomSampler, Sampler
|
from torch.utils.data.sampler import RandomSampler, Sampler
|
||||||
|
|
||||||
@@ -34,6 +34,11 @@ from .utils import logging
|
|||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
if version.parse(torch.__version__) <= version.parse("1.4.1"):
|
||||||
|
SAVE_STATE_WARNING = ""
|
||||||
|
else:
|
||||||
|
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user