Lightning Updates for v0.8.5 (#5798)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -1,14 +1,11 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
|
||||
from transformers import (
|
||||
AdamW,
|
||||
@@ -42,14 +39,6 @@ MODEL_MODES = {
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.gpus > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
class BaseTransformer(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -63,7 +52,11 @@ class BaseTransformer(pl.LightningModule):
|
||||
):
|
||||
"""Initialize a model, tokenizer and config."""
|
||||
super().__init__()
|
||||
self.hparams = hparams # TODO: move to self.save_hyperparameters()
|
||||
# TODO: move to self.save_hyperparameters()
|
||||
# self.save_hyperparameters()
|
||||
# can also expand arguments into trainer signature for easier reading
|
||||
|
||||
self.hparams = hparams
|
||||
self.step_count = 0
|
||||
self.tfmr_ckpts = {}
|
||||
self.output_dir = Path(self.hparams.output_dir)
|
||||
@@ -114,17 +107,12 @@ class BaseTransformer(pl.LightningModule):
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
|
||||
self.opt = optimizer
|
||||
return [optimizer]
|
||||
|
||||
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
|
||||
if self.trainer.use_tpu:
|
||||
xm.optimizer_step(optimizer)
|
||||
else:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
self.lr_scheduler.step() # By default, PL will only step every epoch.
|
||||
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
|
||||
self.logger.log_metrics(lrs)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
|
||||
)
|
||||
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
def test_step(self, batch, batch_nb):
|
||||
return self.validation_step(batch, batch_nb)
|
||||
@@ -132,26 +120,24 @@ class BaseTransformer(pl.LightningModule):
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_end(outputs)
|
||||
|
||||
def train_dataloader(self):
|
||||
def setup(self, step):
|
||||
train_batch_size = self.hparams.train_batch_size
|
||||
dataloader = self.load_dataset("train", train_batch_size)
|
||||
dataloader = self.get_dataloader("train", train_batch_size)
|
||||
self.train_loader = dataloader
|
||||
self.total_steps = (
|
||||
(len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus)))
|
||||
// self.hparams.accumulate_grad_batches
|
||||
* float(self.hparams.max_epochs)
|
||||
)
|
||||
|
||||
t_total = (
|
||||
(len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.n_gpu)))
|
||||
// self.hparams.gradient_accumulation_steps
|
||||
* float(self.hparams.num_train_epochs)
|
||||
)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
self.lr_scheduler = scheduler
|
||||
return dataloader
|
||||
def train_dataloader(self):
|
||||
return self.train_loader
|
||||
|
||||
def val_dataloader(self):
|
||||
return self.load_dataset("dev", self.hparams.eval_batch_size)
|
||||
return self.get_dataloader("dev", self.hparams.eval_batch_size)
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.load_dataset("test", self.hparams.eval_batch_size)
|
||||
return self.get_dataloader("test", self.hparams.eval_batch_size)
|
||||
|
||||
def _feature_file(self, mode):
|
||||
return os.path.join(
|
||||
@@ -201,16 +187,16 @@ class BaseTransformer(pl.LightningModule):
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
|
||||
)
|
||||
|
||||
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
|
||||
parser.add_argument("--train_batch_size", default=32, type=int)
|
||||
parser.add_argument("--eval_batch_size", default=32, type=int)
|
||||
|
||||
|
||||
class LoggingCallback(pl.Callback):
|
||||
@rank_zero_only
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
|
||||
pl_module.logger.log_metrics(lrs)
|
||||
|
||||
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
rank_zero_info("***** Validation results *****")
|
||||
metrics = trainer.callback_metrics
|
||||
@@ -219,16 +205,15 @@ class LoggingCallback(pl.Callback):
|
||||
if key not in ["log", "progress_bar"]:
|
||||
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
||||
|
||||
@rank_zero_only
|
||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
logger.info("***** Test results *****")
|
||||
rank_zero_info("***** Test results *****")
|
||||
metrics = trainer.callback_metrics
|
||||
# Log and save results to file
|
||||
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
|
||||
with open(output_test_results_file, "w") as writer:
|
||||
for key in sorted(metrics):
|
||||
if key not in ["log", "progress_bar"]:
|
||||
logger.info("{} = {}\n".format(key, str(metrics[key])))
|
||||
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
||||
writer.write("{} = {}\n".format(key, str(metrics[key])))
|
||||
|
||||
|
||||
@@ -251,26 +236,23 @@ def add_generic_args(parser, root_dir) -> None:
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
default="O2",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--fast_dev_run", action="store_true")
|
||||
parser.add_argument("--gpus", type=int, default=1)
|
||||
parser.add_argument("--n_tpu_cores", type=int, default=0)
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int, default=0)
|
||||
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
dest="accumulate_grad_batches",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
|
||||
parser.add_argument("--val_check_interval", default=1.0, type=float)
|
||||
|
||||
|
||||
def generic_train(
|
||||
@@ -283,10 +265,13 @@ def generic_train(
|
||||
logging_callback=None,
|
||||
**extra_train_kwargs
|
||||
):
|
||||
pl.seed_everything(args.seed)
|
||||
|
||||
# init model
|
||||
set_seed(args)
|
||||
odir = Path(model.hparams.output_dir)
|
||||
odir.mkdir(exist_ok=True)
|
||||
|
||||
# add custom checkpoints
|
||||
if checkpoint_callback is None:
|
||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
|
||||
@@ -296,38 +281,25 @@ def generic_train(
|
||||
|
||||
train_params = {}
|
||||
|
||||
# TODO: remove with PyTorch 1.6 since pl uses native amp
|
||||
if args.fp16:
|
||||
train_params["use_amp"] = args.fp16
|
||||
train_params["precision"] = 16
|
||||
train_params["amp_level"] = args.fp16_opt_level
|
||||
|
||||
if args.n_tpu_cores > 0:
|
||||
global xm
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
train_params["num_tpu_cores"] = args.n_tpu_cores
|
||||
train_params["gpus"] = 0
|
||||
|
||||
if args.gpus > 1:
|
||||
train_params["distributed_backend"] = "ddp"
|
||||
|
||||
trainer = pl.Trainer(
|
||||
logger=logger,
|
||||
accumulate_grad_batches=args.gradient_accumulation_steps,
|
||||
gpus=args.gpus,
|
||||
max_epochs=args.num_train_epochs,
|
||||
early_stop_callback=early_stopping_callback,
|
||||
gradient_clip_val=args.max_grad_norm,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
callbacks=[logging_callback] + extra_callbacks,
|
||||
fast_dev_run=args.fast_dev_run,
|
||||
val_check_interval=args.val_check_interval,
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
weights_summary=None,
|
||||
resume_from_checkpoint=args.resume_from_checkpoint,
|
||||
callbacks=[logging_callback] + extra_callbacks,
|
||||
logger=logger,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
early_stop_callback=early_stopping_callback,
|
||||
**train_params,
|
||||
)
|
||||
|
||||
if args.do_train:
|
||||
trainer.fit(model)
|
||||
trainer.logger.log_hyperparams(args)
|
||||
trainer.logger.save()
|
||||
|
||||
return trainer
|
||||
|
||||
Reference in New Issue
Block a user