Fix bad import with PyTorch <= 1.4.1 (#8237)

This commit is contained in:
Sylvain Gugger
2020-11-02 10:26:37 -05:00
committed by GitHub
parent 3c8d401cf6
commit d1ad4bff44

View File

@@ -23,7 +23,7 @@ from typing import List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from torch.optim.lr_scheduler import SAVE_STATE_WARNING from packaging import version
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler from torch.utils.data.sampler import RandomSampler, Sampler
@@ -34,6 +34,11 @@ from .utils import logging
if is_torch_tpu_available(): if is_torch_tpu_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
if version.parse(torch.__version__) <= version.parse("1.4.1"):
SAVE_STATE_WARNING = ""
else:
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)