[pl_examples] revert deletion of optimizer_step (#5227)
This commit is contained in:
@@ -116,6 +116,19 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
self.opt = optimizer
|
self.opt = optimizer
|
||||||
return [optimizer]
|
return [optimizer]
|
||||||
|
|
||||||
|
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
|
||||||
|
if self.trainer.use_tpu:
|
||||||
|
xm.optimizer_step(optimizer)
|
||||||
|
else:
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
self.lr_scheduler.step()
|
||||||
|
|
||||||
|
def get_tqdm_dict(self):
|
||||||
|
avg_loss = getattr(self.trainer, "avg_loss", 0.0)
|
||||||
|
tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
|
||||||
|
return tqdm_dict
|
||||||
|
|
||||||
def test_step(self, batch, batch_nb):
|
def test_step(self, batch, batch_nb):
|
||||||
return self.validation_step(batch, batch_nb)
|
return self.validation_step(batch, batch_nb)
|
||||||
|
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
|
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
|
||||||
gen_time = time.time() - t0 / source_ids.shape[0]
|
gen_time = (time.time() - t0) / source_ids.shape[0]
|
||||||
preds = self.ids_to_clean_text(generated_ids)
|
preds = self.ids_to_clean_text(generated_ids)
|
||||||
target = self.ids_to_clean_text(y)
|
target = self.ids_to_clean_text(y)
|
||||||
loss_tensors = self._step(batch)
|
loss_tensors = self._step(batch)
|
||||||
|
|||||||
@@ -7,5 +7,6 @@ python distillation.py \
|
|||||||
--learning_rate=3e-4 \
|
--learning_rate=3e-4 \
|
||||||
--do_train \
|
--do_train \
|
||||||
--do_predict \
|
--do_predict \
|
||||||
|
--fp16 \
|
||||||
--val_check_interval 0.1 \
|
--val_check_interval 0.1 \
|
||||||
$@
|
$@
|
||||||
|
|||||||
Reference in New Issue
Block a user