Upgrade examples to pl=0.8.1(#5146)

This commit is contained in:
Sam Shleifer
2020-06-22 20:40:10 -04:00
committed by GitHub
parent 06b60c8b05
commit f5c2a122e3
11 changed files with 53 additions and 150 deletions

View File

@@ -8,6 +8,7 @@ from typing import Any, Dict
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from transformers import (
AdamW,
@@ -60,10 +61,9 @@ class BaseTransformer(pl.LightningModule):
model=None,
**config_kwargs
):
"Initialize a model."
"""Initialize a model, tokenizer and config."""
super().__init__()
self.hparams = hparams
self.hparams = hparams # TODO: move to self.save_hyperparameters()
self.step_count = 0
self.tfmr_ckpts = {}
self.output_dir = Path(self.hparams.output_dir)
@@ -84,8 +84,8 @@ class BaseTransformer(pl.LightningModule):
)
else:
self.tokenizer: PreTrainedTokenizer = tokenizer
self.model_type = MODEL_MODES[mode]
if model is None:
self.model_type = MODEL_MODES[mode]
self.model = self.model_type.from_pretrained(
self.hparams.model_name_or_path,
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
@@ -93,18 +93,13 @@ class BaseTransformer(pl.LightningModule):
cache_dir=cache_dir,
)
else:
self.model_type = None
self.model = model
def load_hf_checkpoint(self, *args, **kwargs):
self.model = self.model_type.from_pretrained(*args, **kwargs)
def is_logger(self):
return self.trainer.proc_rank <= 0
def configure_optimizers(self):
"Prepare optimizer and schedule (linear warmup and decay)"
model = self.model
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
@@ -121,23 +116,10 @@ class BaseTransformer(pl.LightningModule):
self.opt = optimizer
return [optimizer]
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
if self.trainer.use_tpu:
xm.optimizer_step(optimizer)
else:
optimizer.step()
optimizer.zero_grad()
self.lr_scheduler.step()
def get_tqdm_dict(self):
avg_loss = getattr(self.trainer, "avg_loss", 0.0)
tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
return tqdm_dict
def test_step(self, batch, batch_nb):
return self.validation_step(batch, batch_nb)
def test_end(self, outputs):
def test_epoch_end(self, outputs):
return self.validation_end(outputs)
def train_dataloader(self):
@@ -208,6 +190,7 @@ class BaseTransformer(pl.LightningModule):
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
parser.add_argument(
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
)
@@ -217,28 +200,26 @@ class BaseTransformer(pl.LightningModule):
class LoggingCallback(pl.Callback):
@rank_zero_only
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
logger.info("***** Validation results *****")
if pl_module.is_logger():
metrics = trainer.callback_metrics
# Log results
rank_zero_info("***** Validation results *****")
metrics = trainer.callback_metrics
# Log results
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
logger.info("***** Test results *****")
metrics = trainer.callback_metrics
# Log and save results to file
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
logger.info("{} = {}\n".format(key, str(metrics[key])))
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
logger.info("***** Test results *****")
if pl_module.is_logger():
metrics = trainer.callback_metrics
# Log and save results to file
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
logger.info("{} = {}\n".format(key, str(metrics[key])))
writer.write("{} = {}\n".format(key, str(metrics[key])))
writer.write("{} = {}\n".format(key, str(metrics[key])))
def add_generic_args(parser, root_dir) -> None: