From 90a41dbe1404f734f6a25bfbaf89be71ba5e4613 Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Sat, 9 Mar 2019 02:23:20 +0100 Subject: [PATCH] BertAdam schedule objects --- pytorch_pretrained_bert/__init__.py | 2 +- pytorch_pretrained_bert/optimization.py | 48 +++++++++---------------- 2 files changed, 17 insertions(+), 33 deletions(-) diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py index bd455b8d9c..e82d409ee0 100644 --- a/pytorch_pretrained_bert/__init__.py +++ b/pytorch_pretrained_bert/__init__.py @@ -18,7 +18,7 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, load_tf_weights_in_gpt2) -from .optimization import BertAdam +from .optimization import * from .optimization_openai import OpenAIAdam from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index 73afc71058..cea35c39e9 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -24,6 +24,9 @@ import logging logger = logging.getLogger(__name__) +__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam"] + + class LRSchedule(object): warn_t_total = False def __init__(self, warmup=0.002, t_total=-1, **kw): @@ -83,32 +86,7 @@ class WarmupLinearSchedule(LRSchedule): if progress < self.warmup: return progress / self.warmup return max((progress - 1.) / (self.warmup - 1.), 0) -# -# -# def warmup_cosine(x, warmup=0.002): -# if x < warmup: -# return x/warmup -# return 0.5 * (1.0 + torch.cos(math.pi * x)) -# -# def warmup_constant(x, warmup=0.002): -# """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. -# Learning rate is 1. afterwards. """ -# if x < warmup: -# return x/warmup -# return 1.0 -# -# def warmup_linear(x, warmup=0.002): -# """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. -# After `t_total`-th training step, learning rate is zero. """ -# if x < warmup: -# return x/warmup -# return max((x-1.)/(warmup-1.), 0) -# -# SCHEDULES = { -# 'warmup_cosine': warmup_cosine, -# 'warmup_constant': warmup_constant, -# 'warmup_linear': warmup_linear, -# } + SCHEDULES = { None: LRSchedule, @@ -126,7 +104,9 @@ class BertAdam(Optimizer): warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 t_total: total number of training steps for the learning rate schedule, -1 means constant learning rate. Default: -1 - schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' + schedule: schedule to use for the warmup (see above). + Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object. + Default: 'warmup_linear' b1: Adams b1. Default: 0.9 b2: Adams b2. Default: 0.999 e: Adams epsilon. Default: 1e-6 @@ -147,9 +127,13 @@ class BertAdam(Optimizer): if not e >= 0.0: raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) # initialize schedule object - schedule_type = SCHEDULES[schedule] - sched = schedule_type(warmup=warmup, t_total=t_total) - defaults = dict(lr=lr, schedule=sched, + if not isinstance(schedule, LRSchedule): + schedule_type = SCHEDULES[schedule] + schedule = schedule_type(warmup=warmup, t_total=t_total) + else: + if warmup != -1 or t_total != -1: + logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.") + defaults = dict(lr=lr, schedule=schedule, b1=b1, b2=b2, e=e, weight_decay=weight_decay, max_grad_norm=max_grad_norm) super(BertAdam, self).__init__(params, defaults) @@ -163,7 +147,7 @@ class BertAdam(Optimizer): return [0] lr_scheduled = group['lr'] - lr_scheduled *= group['schedule'](state['step']) + lr_scheduled *= group['schedule'].get_lr(state['step']) lr.append(lr_scheduled) return lr @@ -221,7 +205,7 @@ class BertAdam(Optimizer): update += group['weight_decay'] * p.data lr_scheduled = group['lr'] - lr_scheduled *= group['schedule'](state['step']) + lr_scheduled *= group['schedule'].get_lr(state['step']) update_with_lr = lr_scheduled * update p.data.add_(-update_with_lr)