python 2 compat
This commit is contained in:
@@ -20,11 +20,18 @@ from torch.optim import Optimizer
|
||||
from torch.optim.optimizer import required
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
import abc
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if sys.version_info >= (3, 4):
|
||||
ABC = abc.ABC
|
||||
else:
|
||||
ABC = abc.ABCMeta('ABC', (), {})
|
||||
|
||||
|
||||
class _LRSchedule(ABC):
|
||||
""" Parent of all LRSchedules here. """
|
||||
warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
|
||||
@@ -62,7 +69,7 @@ class _LRSchedule(ABC):
|
||||
# end warning
|
||||
return ret
|
||||
|
||||
@abstractmethod
|
||||
@abc.abstractmethod
|
||||
def get_lr_(self, progress):
|
||||
"""
|
||||
:param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
|
||||
|
||||
Reference in New Issue
Block a user