Add model.train() line to ReadMe training example
Co-Authored-By: Santosh-Gupta <San.Gupta.ML@gmail.com>
This commit is contained in:
@@ -538,6 +538,7 @@ optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False) # To reproduce
|
||||
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps) # PyTorch scheduler
|
||||
### and used like this:
|
||||
for batch in train_data:
|
||||
model.train()
|
||||
loss = model(batch)
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
|
||||
|
||||
Reference in New Issue
Block a user