relation classification: replacing entity mention with mask token
This commit is contained in:
@@ -130,7 +130,7 @@ class BertAdam(Optimizer):
|
||||
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
|
||||
"""
|
||||
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
|
||||
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
|
||||
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, init_weight_decay=0.,
|
||||
max_grad_norm=1.0):
|
||||
if lr is not required and lr < 0.0:
|
||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||
@@ -150,7 +150,7 @@ class BertAdam(Optimizer):
|
||||
if warmup != -1 or t_total != -1:
|
||||
logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.")
|
||||
defaults = dict(lr=lr, schedule=schedule,
|
||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay, init_weight_decay=init_weight_decay,
|
||||
max_grad_norm=max_grad_norm)
|
||||
super(BertAdam, self).__init__(params, defaults)
|
||||
|
||||
@@ -220,6 +220,8 @@ class BertAdam(Optimizer):
|
||||
if group['weight_decay'] > 0.0:
|
||||
update += group['weight_decay'] * p.data
|
||||
|
||||
# TODO: init weight decay
|
||||
|
||||
lr_scheduled = group['lr']
|
||||
lr_scheduled *= group['schedule'].get_lr(state['step'])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user