fix tests
This commit is contained in:
@@ -33,8 +33,8 @@ from torch.utils.data import TensorDataset, DataLoader, RandomSampler, Sequentia
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from pytorch_pretrained_bert.tokenization import printable_text, whitespace_tokenize, BasicTokenizer, BertTokenizer
|
||||
from pytorch_pretrained_bert.modeling import BertConfig, BertForQuestionAnswering
|
||||
from pytorch_pretrained_bert.optimization import BERTAdam
|
||||
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
|
||||
from pytorch_pretrained_bert.optimization import BertAdam
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
@@ -847,7 +847,7 @@ def main():
|
||||
{'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
|
||||
{'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
|
||||
]
|
||||
optimizer = BERTAdam(optimizer_grouped_parameters,
|
||||
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||
lr=args.learning_rate,
|
||||
warmup=args.warmup_proportion,
|
||||
t_total=num_train_steps)
|
||||
|
||||
Reference in New Issue
Block a user