From 20e652209c7da7a73c9d1f3a65418d0ea118680e Mon Sep 17 00:00:00 2001 From: lukovnikov Date: Wed, 13 Mar 2019 16:13:37 +0100 Subject: [PATCH] relation classification: replacing entity mention with mask token --- pytorch_pretrained_bert/optimization.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index 7eda3ba92a..9a873e221b 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -130,7 +130,7 @@ class BertAdam(Optimizer): max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 """ def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', - b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, + b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, init_weight_decay=0., max_grad_norm=1.0): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) @@ -150,7 +150,7 @@ class BertAdam(Optimizer): if warmup != -1 or t_total != -1: logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.") defaults = dict(lr=lr, schedule=schedule, - b1=b1, b2=b2, e=e, weight_decay=weight_decay, + b1=b1, b2=b2, e=e, weight_decay=weight_decay, init_weight_decay=init_weight_decay, max_grad_norm=max_grad_norm) super(BertAdam, self).__init__(params, defaults) @@ -220,6 +220,8 @@ class BertAdam(Optimizer): if group['weight_decay'] > 0.0: update += group['weight_decay'] * p.data + # TODO: init weight decay + lr_scheduled = group['lr'] lr_scheduled *= group['schedule'].get_lr(state['step'])