BertAdam schedule objects
This commit is contained in:
@@ -18,7 +18,7 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model,
|
|||||||
GPT2LMHeadModel, GPT2DoubleHeadsModel,
|
GPT2LMHeadModel, GPT2DoubleHeadsModel,
|
||||||
load_tf_weights_in_gpt2)
|
load_tf_weights_in_gpt2)
|
||||||
|
|
||||||
from .optimization import BertAdam
|
from .optimization import *
|
||||||
from .optimization_openai import OpenAIAdam
|
from .optimization_openai import OpenAIAdam
|
||||||
|
|
||||||
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
|
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
|
||||||
|
|||||||
@@ -24,6 +24,9 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam"]
|
||||||
|
|
||||||
|
|
||||||
class LRSchedule(object):
|
class LRSchedule(object):
|
||||||
warn_t_total = False
|
warn_t_total = False
|
||||||
def __init__(self, warmup=0.002, t_total=-1, **kw):
|
def __init__(self, warmup=0.002, t_total=-1, **kw):
|
||||||
@@ -83,32 +86,7 @@ class WarmupLinearSchedule(LRSchedule):
|
|||||||
if progress < self.warmup:
|
if progress < self.warmup:
|
||||||
return progress / self.warmup
|
return progress / self.warmup
|
||||||
return max((progress - 1.) / (self.warmup - 1.), 0)
|
return max((progress - 1.) / (self.warmup - 1.), 0)
|
||||||
#
|
|
||||||
#
|
|
||||||
# def warmup_cosine(x, warmup=0.002):
|
|
||||||
# if x < warmup:
|
|
||||||
# return x/warmup
|
|
||||||
# return 0.5 * (1.0 + torch.cos(math.pi * x))
|
|
||||||
#
|
|
||||||
# def warmup_constant(x, warmup=0.002):
|
|
||||||
# """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
|
|
||||||
# Learning rate is 1. afterwards. """
|
|
||||||
# if x < warmup:
|
|
||||||
# return x/warmup
|
|
||||||
# return 1.0
|
|
||||||
#
|
|
||||||
# def warmup_linear(x, warmup=0.002):
|
|
||||||
# """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
|
|
||||||
# After `t_total`-th training step, learning rate is zero. """
|
|
||||||
# if x < warmup:
|
|
||||||
# return x/warmup
|
|
||||||
# return max((x-1.)/(warmup-1.), 0)
|
|
||||||
#
|
|
||||||
# SCHEDULES = {
|
|
||||||
# 'warmup_cosine': warmup_cosine,
|
|
||||||
# 'warmup_constant': warmup_constant,
|
|
||||||
# 'warmup_linear': warmup_linear,
|
|
||||||
# }
|
|
||||||
|
|
||||||
SCHEDULES = {
|
SCHEDULES = {
|
||||||
None: LRSchedule,
|
None: LRSchedule,
|
||||||
@@ -126,7 +104,9 @@ class BertAdam(Optimizer):
|
|||||||
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
|
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
|
||||||
t_total: total number of training steps for the learning
|
t_total: total number of training steps for the learning
|
||||||
rate schedule, -1 means constant learning rate. Default: -1
|
rate schedule, -1 means constant learning rate. Default: -1
|
||||||
schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
|
schedule: schedule to use for the warmup (see above).
|
||||||
|
Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object.
|
||||||
|
Default: 'warmup_linear'
|
||||||
b1: Adams b1. Default: 0.9
|
b1: Adams b1. Default: 0.9
|
||||||
b2: Adams b2. Default: 0.999
|
b2: Adams b2. Default: 0.999
|
||||||
e: Adams epsilon. Default: 1e-6
|
e: Adams epsilon. Default: 1e-6
|
||||||
@@ -147,9 +127,13 @@ class BertAdam(Optimizer):
|
|||||||
if not e >= 0.0:
|
if not e >= 0.0:
|
||||||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
||||||
# initialize schedule object
|
# initialize schedule object
|
||||||
schedule_type = SCHEDULES[schedule]
|
if not isinstance(schedule, LRSchedule):
|
||||||
sched = schedule_type(warmup=warmup, t_total=t_total)
|
schedule_type = SCHEDULES[schedule]
|
||||||
defaults = dict(lr=lr, schedule=sched,
|
schedule = schedule_type(warmup=warmup, t_total=t_total)
|
||||||
|
else:
|
||||||
|
if warmup != -1 or t_total != -1:
|
||||||
|
logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.")
|
||||||
|
defaults = dict(lr=lr, schedule=schedule,
|
||||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
||||||
max_grad_norm=max_grad_norm)
|
max_grad_norm=max_grad_norm)
|
||||||
super(BertAdam, self).__init__(params, defaults)
|
super(BertAdam, self).__init__(params, defaults)
|
||||||
@@ -163,7 +147,7 @@ class BertAdam(Optimizer):
|
|||||||
return [0]
|
return [0]
|
||||||
|
|
||||||
lr_scheduled = group['lr']
|
lr_scheduled = group['lr']
|
||||||
lr_scheduled *= group['schedule'](state['step'])
|
lr_scheduled *= group['schedule'].get_lr(state['step'])
|
||||||
|
|
||||||
lr.append(lr_scheduled)
|
lr.append(lr_scheduled)
|
||||||
return lr
|
return lr
|
||||||
@@ -221,7 +205,7 @@ class BertAdam(Optimizer):
|
|||||||
update += group['weight_decay'] * p.data
|
update += group['weight_decay'] * p.data
|
||||||
|
|
||||||
lr_scheduled = group['lr']
|
lr_scheduled = group['lr']
|
||||||
lr_scheduled *= group['schedule'](state['step'])
|
lr_scheduled *= group['schedule'].get_lr(state['step'])
|
||||||
|
|
||||||
update_with_lr = lr_scheduled * update
|
update_with_lr = lr_scheduled * update
|
||||||
p.data.add_(-update_with_lr)
|
p.data.add_(-update_with_lr)
|
||||||
|
|||||||
Reference in New Issue
Block a user