fix tests

This commit is contained in:
thomwolf
2018-11-17 11:58:14 +01:00
parent a99b971738
commit 757750d6f6
9 changed files with 48 additions and 61 deletions

View File

@@ -2,4 +2,4 @@ from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForQuestionAnswering)
from .optimization import BERTAdam
from .optimization import BertAdam

View File

@@ -41,7 +41,7 @@ SCHEDULES = {
}
class BERTAdam(Optimizer):
class BertAdam(Optimizer):
"""Implements BERT version of Adam algorithm with weight decay fix.
Params:
lr: learning rate
@@ -73,7 +73,7 @@ class BERTAdam(Optimizer):
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
max_grad_norm=max_grad_norm)
super(BERTAdam, self).__init__(params, defaults)
super(BertAdam, self).__init__(params, defaults)
def get_lr(self):
lr = []