python 2 compat
This commit is contained in:
@@ -20,11 +20,18 @@ from torch.optim import Optimizer
|
|||||||
from torch.optim.optimizer import required
|
from torch.optim.optimizer import required
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
import abc
|
||||||
|
import sys
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 4):
|
||||||
|
ABC = abc.ABC
|
||||||
|
else:
|
||||||
|
ABC = abc.ABCMeta('ABC', (), {})
|
||||||
|
|
||||||
|
|
||||||
class _LRSchedule(ABC):
|
class _LRSchedule(ABC):
|
||||||
""" Parent of all LRSchedules here. """
|
""" Parent of all LRSchedules here. """
|
||||||
warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
|
warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
|
||||||
@@ -62,7 +69,7 @@ class _LRSchedule(ABC):
|
|||||||
# end warning
|
# end warning
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@abstractmethod
|
@abc.abstractmethod
|
||||||
def get_lr_(self, progress):
|
def get_lr_(self, progress):
|
||||||
"""
|
"""
|
||||||
:param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
|
:param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
|
||||||
|
|||||||
Reference in New Issue
Block a user