From 76e5af4cfd821c0c610b9927a2d2cd58a02f43e4 Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Tue, 23 Jun 2020 16:40:45 -0400
Subject: [PATCH] [pl_examples] revert deletion of optimizer_step (#5227)

---
 examples/lightning_base.py              | 13 +++++++++++++
 examples/summarization/finetune.py      |  2 +-
 examples/summarization/run_distiller.sh |  1 +
 3 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/examples/lightning_base.py b/examples/lightning_base.py
index 2574aa9458..5a69a26e5b 100644
--- a/examples/lightning_base.py
+++ b/examples/lightning_base.py
@@ -116,6 +116,19 @@ class BaseTransformer(pl.LightningModule):
         self.opt = optimizer
         return [optimizer]
 
+    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
+        if self.trainer.use_tpu:
+            xm.optimizer_step(optimizer)
+        else:
+            optimizer.step()
+        optimizer.zero_grad()
+        self.lr_scheduler.step()
+
+    def get_tqdm_dict(self):
+        avg_loss = getattr(self.trainer, "avg_loss", 0.0)
+        tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
+        return tqdm_dict
+
     def test_step(self, batch, batch_nb):
         return self.validation_step(batch, batch_nb)
 
diff --git a/examples/summarization/finetune.py b/examples/summarization/finetune.py
index f2e3f64637..a10d3f6511 100644
--- a/examples/summarization/finetune.py
+++ b/examples/summarization/finetune.py
@@ -149,7 +149,7 @@ class SummarizationModule(BaseTransformer):
         source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
         t0 = time.time()
         generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
-        gen_time = time.time() - t0 / source_ids.shape[0]
+        gen_time = (time.time() - t0) / source_ids.shape[0]
         preds = self.ids_to_clean_text(generated_ids)
         target = self.ids_to_clean_text(y)
         loss_tensors = self._step(batch)
diff --git a/examples/summarization/run_distiller.sh b/examples/summarization/run_distiller.sh
index a4d43de64a..6fbecad388 100755
--- a/examples/summarization/run_distiller.sh
+++ b/examples/summarization/run_distiller.sh
@@ -7,5 +7,6 @@ python distillation.py \
 --learning_rate=3e-4 \
 --do_train \
 --do_predict \
+--fp16 \
 --val_check_interval 0.1 \
 $@