fixes lr_scheduler warning

For more details, see https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
This commit is contained in:
Elijah Rippeth
2020-03-20 17:41:32 -04:00
committed by Julien Chaumond
parent 265709f5cd
commit 634bf6cf7e

View File

@@ -195,8 +195,8 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
scheduler.step() # Update learning rate schedule
optimizer.step()
scheduler.step() # Update learning rate schedule
model.zero_grad()
global_step += 1