From 68f7064a3ea979cdbdadfed62ad655eac4c53463 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 4 Nov 2019 11:52:35 -0500 Subject: [PATCH] Add `model.train()` line to ReadMe training example Co-Authored-By: Santosh-Gupta --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 6721163d16..40b08583b1 100644 --- a/README.md +++ b/README.md @@ -538,6 +538,7 @@ optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False) # To reproduce scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps) # PyTorch scheduler ### and used like this: for batch in train_data: + model.train() loss = model(batch) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # Gradient clipping is not in AdamW anymore (so you can use amp without issue)