fixing optimization
This commit is contained in:
@@ -483,10 +483,14 @@ def main():
|
||||
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||
model.to(device)
|
||||
|
||||
optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
|
||||
{'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
|
||||
],
|
||||
lr=args.learning_rate, schedule='warmup_linear',
|
||||
no_decay = ['bias', 'gamma', 'beta']
|
||||
optimizer_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01},
|
||||
{'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0}
|
||||
]
|
||||
|
||||
optimizer = BERTAdam(optimizer_parameters,
|
||||
lr=args.learning_rate,
|
||||
warmup=args.warmup_proportion,
|
||||
t_total=num_train_steps)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user