[pl_examples] revert deletion of optimizer_step (#5227)

This commit is contained in:
Sam Shleifer
2020-06-23 16:40:45 -04:00
committed by GitHub
parent c01480bba3
commit 76e5af4cfd
3 changed files with 15 additions and 1 deletions

View File

@@ -116,6 +116,19 @@ class BaseTransformer(pl.LightningModule):
self.opt = optimizer
return [optimizer]
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
if self.trainer.use_tpu:
xm.optimizer_step(optimizer)
else:
optimizer.step()
optimizer.zero_grad()
self.lr_scheduler.step()
def get_tqdm_dict(self):
avg_loss = getattr(self.trainer, "avg_loss", 0.0)
tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
return tqdm_dict
def test_step(self, batch, batch_nb):
return self.validation_step(batch, batch_nb)