diff --git a/README.md b/README.md index 6721163d16..40b08583b1 100644 --- a/README.md +++ b/README.md @@ -538,6 +538,7 @@ optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False) # To reproduce scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps) # PyTorch scheduler ### and used like this: for batch in train_data: + model.train() loss = model(batch) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # Gradient clipping is not in AdamW anymore (so you can use amp without issue)