[examples] SummarizationModule improvements (#4951)
This commit is contained in:
@@ -2,6 +2,8 @@ import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
@@ -13,10 +15,13 @@ from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForPreTraining,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelWithLMHead,
|
||||
AutoTokenizer,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
|
||||
@@ -31,6 +36,8 @@ MODEL_MODES = {
|
||||
"pretraining": AutoModelForPreTraining,
|
||||
"token-classification": AutoModelForTokenClassification,
|
||||
"language-modeling": AutoModelWithLMHead,
|
||||
"summarization": AutoModelForSeq2SeqLM,
|
||||
"translation": AutoModelForSeq2SeqLM,
|
||||
}
|
||||
|
||||
|
||||
@@ -38,33 +45,59 @@ def set_seed(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
if args.gpus > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
class BaseTransformer(pl.LightningModule):
|
||||
def __init__(self, hparams: argparse.Namespace, num_labels=None, mode="base", **config_kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
hparams: argparse.Namespace,
|
||||
num_labels=None,
|
||||
mode="base",
|
||||
config=None,
|
||||
tokenizer=None,
|
||||
model=None,
|
||||
**config_kwargs
|
||||
):
|
||||
"Initialize a model."
|
||||
|
||||
super().__init__()
|
||||
self.hparams = hparams
|
||||
self.step_count = 0
|
||||
self.tfmr_ckpts = {}
|
||||
self.output_dir = Path(self.hparams.output_dir)
|
||||
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
||||
self.config = AutoConfig.from_pretrained(
|
||||
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
|
||||
**({"num_labels": num_labels} if num_labels is not None else {}),
|
||||
cache_dir=cache_dir,
|
||||
**config_kwargs,
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
self.model = MODEL_MODES[mode].from_pretrained(
|
||||
self.hparams.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
||||
config=self.config,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
if config is None:
|
||||
self.config = AutoConfig.from_pretrained(
|
||||
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
|
||||
**({"num_labels": num_labels} if num_labels is not None else {}),
|
||||
cache_dir=cache_dir,
|
||||
**config_kwargs,
|
||||
)
|
||||
else:
|
||||
self.config: PretrainedConfig = config
|
||||
if tokenizer is None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
else:
|
||||
self.tokenizer: PreTrainedTokenizer = tokenizer
|
||||
if model is None:
|
||||
self.model_type = MODEL_MODES[mode]
|
||||
self.model = self.model_type.from_pretrained(
|
||||
self.hparams.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
||||
config=self.config,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
else:
|
||||
self.model_type = None
|
||||
self.model = model
|
||||
|
||||
def load_hf_checkpoint(self, *args, **kwargs):
|
||||
self.model = self.model_type.from_pretrained(*args, **kwargs)
|
||||
|
||||
def is_logger(self):
|
||||
return self.trainer.proc_rank <= 0
|
||||
@@ -138,6 +171,15 @@ class BaseTransformer(pl.LightningModule):
|
||||
),
|
||||
)
|
||||
|
||||
@pl.utilities.rank_zero_only
|
||||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||
save_path = self.output_dir.joinpath("best_tfmr")
|
||||
save_path.mkdir(exist_ok=True)
|
||||
self.model.config.save_step = self.step_count
|
||||
self.model.save_pretrained(save_path)
|
||||
self.tokenizer.save_pretrained(save_path)
|
||||
self.tfmr_ckpts[self.step_count] = save_path
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
parser.add_argument(
|
||||
@@ -152,7 +194,7 @@ class BaseTransformer(pl.LightningModule):
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
@@ -165,7 +207,7 @@ class BaseTransformer(pl.LightningModule):
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
|
||||
)
|
||||
@@ -199,7 +241,8 @@ class LoggingCallback(pl.Callback):
|
||||
writer.write("{} = {}\n".format(key, str(metrics[key])))
|
||||
|
||||
|
||||
def add_generic_args(parser, root_dir):
|
||||
def add_generic_args(parser, root_dir) -> None:
|
||||
# TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
@@ -221,8 +264,8 @@ def add_generic_args(parser, root_dir):
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
|
||||
parser.add_argument("--n_gpu", type=int, default=1)
|
||||
parser.add_argument("--fast_dev_run", action="store_true")
|
||||
parser.add_argument("--gpus", type=int, default=1)
|
||||
parser.add_argument("--n_tpu_cores", type=int, default=0)
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
@@ -235,28 +278,32 @@ def add_generic_args(parser, root_dir):
|
||||
)
|
||||
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
|
||||
parser.add_argument("--val_check_interval", default=1.0, type=float)
|
||||
|
||||
|
||||
def generic_train(model: BaseTransformer, args: argparse.Namespace):
|
||||
def generic_train(
|
||||
model: BaseTransformer,
|
||||
args: argparse.Namespace,
|
||||
early_stopping_callback=False,
|
||||
logger=True, # can pass WandbLogger() here
|
||||
extra_callbacks=[],
|
||||
checkpoint_callback=None,
|
||||
logging_callback=None,
|
||||
**extra_train_kwargs
|
||||
):
|
||||
# init model
|
||||
set_seed(args)
|
||||
odir = Path(model.hparams.output_dir)
|
||||
odir.mkdir(exist_ok=True)
|
||||
if checkpoint_callback is None:
|
||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
|
||||
)
|
||||
if logging_callback is None:
|
||||
logging_callback = LoggingCallback()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
|
||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
|
||||
)
|
||||
|
||||
train_params = dict(
|
||||
accumulate_grad_batches=args.gradient_accumulation_steps,
|
||||
gpus=args.n_gpu,
|
||||
max_epochs=args.num_train_epochs,
|
||||
early_stop_callback=False,
|
||||
gradient_clip_val=args.max_grad_norm,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
callbacks=[LoggingCallback()],
|
||||
)
|
||||
train_params = {}
|
||||
|
||||
if args.fp16:
|
||||
train_params["use_amp"] = args.fp16
|
||||
@@ -269,12 +316,27 @@ def generic_train(model: BaseTransformer, args: argparse.Namespace):
|
||||
train_params["num_tpu_cores"] = args.n_tpu_cores
|
||||
train_params["gpus"] = 0
|
||||
|
||||
if args.n_gpu > 1:
|
||||
if args.gpus > 1:
|
||||
train_params["distributed_backend"] = "ddp"
|
||||
|
||||
trainer = pl.Trainer(**train_params)
|
||||
trainer = pl.Trainer(
|
||||
logger=logger,
|
||||
accumulate_grad_batches=args.gradient_accumulation_steps,
|
||||
gpus=args.gpus,
|
||||
max_epochs=args.num_train_epochs,
|
||||
early_stop_callback=early_stopping_callback,
|
||||
gradient_clip_val=args.max_grad_norm,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
callbacks=[logging_callback] + extra_callbacks,
|
||||
fast_dev_run=args.fast_dev_run,
|
||||
val_check_interval=args.val_check_interval,
|
||||
weights_summary=None,
|
||||
resume_from_checkpoint=args.resume_from_checkpoint,
|
||||
**train_params,
|
||||
)
|
||||
|
||||
if args.do_train:
|
||||
trainer.fit(model)
|
||||
|
||||
trainer.logger.log_hyperparams(args)
|
||||
trainer.logger.save()
|
||||
return trainer
|
||||
|
||||
Reference in New Issue
Block a user