relation classification: replacing entity mention with mask token
This commit is contained in:
@@ -130,7 +130,7 @@ class BertAdam(Optimizer):
|
|||||||
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
|
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
|
||||||
"""
|
"""
|
||||||
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
|
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
|
||||||
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
|
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, init_weight_decay=0.,
|
||||||
max_grad_norm=1.0):
|
max_grad_norm=1.0):
|
||||||
if lr is not required and lr < 0.0:
|
if lr is not required and lr < 0.0:
|
||||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||||
@@ -150,7 +150,7 @@ class BertAdam(Optimizer):
|
|||||||
if warmup != -1 or t_total != -1:
|
if warmup != -1 or t_total != -1:
|
||||||
logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.")
|
logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.")
|
||||||
defaults = dict(lr=lr, schedule=schedule,
|
defaults = dict(lr=lr, schedule=schedule,
|
||||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
b1=b1, b2=b2, e=e, weight_decay=weight_decay, init_weight_decay=init_weight_decay,
|
||||||
max_grad_norm=max_grad_norm)
|
max_grad_norm=max_grad_norm)
|
||||||
super(BertAdam, self).__init__(params, defaults)
|
super(BertAdam, self).__init__(params, defaults)
|
||||||
|
|
||||||
@@ -220,6 +220,8 @@ class BertAdam(Optimizer):
|
|||||||
if group['weight_decay'] > 0.0:
|
if group['weight_decay'] > 0.0:
|
||||||
update += group['weight_decay'] * p.data
|
update += group['weight_decay'] * p.data
|
||||||
|
|
||||||
|
# TODO: init weight decay
|
||||||
|
|
||||||
lr_scheduled = group['lr']
|
lr_scheduled = group['lr']
|
||||||
lr_scheduled *= group['schedule'].get_lr(state['step'])
|
lr_scheduled *= group['schedule'].get_lr(state['step'])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user