fixing optimization

This commit is contained in:
thomwolf
2018-11-03 17:38:15 +01:00
parent 852e4b3c00
commit 088ad45888
4 changed files with 85 additions and 49 deletions

View File

@@ -483,10 +483,14 @@ def main():
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device)
optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
{'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
],
lr=args.learning_rate, schedule='warmup_linear',
no_decay = ['bias', 'gamma', 'beta']
optimizer_parameters = [
{'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01},
{'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0}
]
optimizer = BERTAdam(optimizer_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_steps)