fix tests
This commit is contained in:
@@ -2,4 +2,4 @@ from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
||||
from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
||||
BertForMaskedLM, BertForNextSentencePrediction,
|
||||
BertForSequenceClassification, BertForQuestionAnswering)
|
||||
from .optimization import BERTAdam
|
||||
from .optimization import BertAdam
|
||||
|
||||
@@ -41,7 +41,7 @@ SCHEDULES = {
|
||||
}
|
||||
|
||||
|
||||
class BERTAdam(Optimizer):
|
||||
class BertAdam(Optimizer):
|
||||
"""Implements BERT version of Adam algorithm with weight decay fix.
|
||||
Params:
|
||||
lr: learning rate
|
||||
@@ -73,7 +73,7 @@ class BERTAdam(Optimizer):
|
||||
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
|
||||
b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
|
||||
max_grad_norm=max_grad_norm)
|
||||
super(BERTAdam, self).__init__(params, defaults)
|
||||
super(BertAdam, self).__init__(params, defaults)
|
||||
|
||||
def get_lr(self):
|
||||
lr = []
|
||||
|
||||
Reference in New Issue
Block a user