Lightning Updates for v0.8.5 (#5798)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -1,14 +1,11 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
from pytorch_lightning.utilities import rank_zero_info
|
||||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AdamW,
|
AdamW,
|
||||||
@@ -42,14 +39,6 @@ MODEL_MODES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def set_seed(args: argparse.Namespace):
|
|
||||||
random.seed(args.seed)
|
|
||||||
np.random.seed(args.seed)
|
|
||||||
torch.manual_seed(args.seed)
|
|
||||||
if args.gpus > 0:
|
|
||||||
torch.cuda.manual_seed_all(args.seed)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTransformer(pl.LightningModule):
|
class BaseTransformer(pl.LightningModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -63,7 +52,11 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
):
|
):
|
||||||
"""Initialize a model, tokenizer and config."""
|
"""Initialize a model, tokenizer and config."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hparams = hparams # TODO: move to self.save_hyperparameters()
|
# TODO: move to self.save_hyperparameters()
|
||||||
|
# self.save_hyperparameters()
|
||||||
|
# can also expand arguments into trainer signature for easier reading
|
||||||
|
|
||||||
|
self.hparams = hparams
|
||||||
self.step_count = 0
|
self.step_count = 0
|
||||||
self.tfmr_ckpts = {}
|
self.tfmr_ckpts = {}
|
||||||
self.output_dir = Path(self.hparams.output_dir)
|
self.output_dir = Path(self.hparams.output_dir)
|
||||||
@@ -114,17 +107,12 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
]
|
]
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
|
||||||
self.opt = optimizer
|
self.opt = optimizer
|
||||||
return [optimizer]
|
|
||||||
|
|
||||||
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
if self.trainer.use_tpu:
|
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
|
||||||
xm.optimizer_step(optimizer)
|
)
|
||||||
else:
|
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
|
||||||
optimizer.step()
|
return [optimizer], [scheduler]
|
||||||
optimizer.zero_grad()
|
|
||||||
self.lr_scheduler.step() # By default, PL will only step every epoch.
|
|
||||||
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
|
|
||||||
self.logger.log_metrics(lrs)
|
|
||||||
|
|
||||||
def test_step(self, batch, batch_nb):
|
def test_step(self, batch, batch_nb):
|
||||||
return self.validation_step(batch, batch_nb)
|
return self.validation_step(batch, batch_nb)
|
||||||
@@ -132,26 +120,24 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
def test_epoch_end(self, outputs):
|
def test_epoch_end(self, outputs):
|
||||||
return self.validation_end(outputs)
|
return self.validation_end(outputs)
|
||||||
|
|
||||||
def train_dataloader(self):
|
def setup(self, step):
|
||||||
train_batch_size = self.hparams.train_batch_size
|
train_batch_size = self.hparams.train_batch_size
|
||||||
dataloader = self.load_dataset("train", train_batch_size)
|
dataloader = self.get_dataloader("train", train_batch_size)
|
||||||
|
self.train_loader = dataloader
|
||||||
|
self.total_steps = (
|
||||||
|
(len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus)))
|
||||||
|
// self.hparams.accumulate_grad_batches
|
||||||
|
* float(self.hparams.max_epochs)
|
||||||
|
)
|
||||||
|
|
||||||
t_total = (
|
def train_dataloader(self):
|
||||||
(len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.n_gpu)))
|
return self.train_loader
|
||||||
// self.hparams.gradient_accumulation_steps
|
|
||||||
* float(self.hparams.num_train_epochs)
|
|
||||||
)
|
|
||||||
scheduler = get_linear_schedule_with_warmup(
|
|
||||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
|
|
||||||
)
|
|
||||||
self.lr_scheduler = scheduler
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
return self.load_dataset("dev", self.hparams.eval_batch_size)
|
return self.get_dataloader("dev", self.hparams.eval_batch_size)
|
||||||
|
|
||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
return self.load_dataset("test", self.hparams.eval_batch_size)
|
return self.get_dataloader("test", self.hparams.eval_batch_size)
|
||||||
|
|
||||||
def _feature_file(self, mode):
|
def _feature_file(self, mode):
|
||||||
return os.path.join(
|
return os.path.join(
|
||||||
@@ -201,16 +187,16 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||||
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
|
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
|
||||||
parser.add_argument(
|
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
|
||||||
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("--train_batch_size", default=32, type=int)
|
parser.add_argument("--train_batch_size", default=32, type=int)
|
||||||
parser.add_argument("--eval_batch_size", default=32, type=int)
|
parser.add_argument("--eval_batch_size", default=32, type=int)
|
||||||
|
|
||||||
|
|
||||||
class LoggingCallback(pl.Callback):
|
class LoggingCallback(pl.Callback):
|
||||||
@rank_zero_only
|
def on_batch_end(self, trainer, pl_module):
|
||||||
|
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
|
||||||
|
pl_module.logger.log_metrics(lrs)
|
||||||
|
|
||||||
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||||
rank_zero_info("***** Validation results *****")
|
rank_zero_info("***** Validation results *****")
|
||||||
metrics = trainer.callback_metrics
|
metrics = trainer.callback_metrics
|
||||||
@@ -219,16 +205,15 @@ class LoggingCallback(pl.Callback):
|
|||||||
if key not in ["log", "progress_bar"]:
|
if key not in ["log", "progress_bar"]:
|
||||||
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
||||||
|
|
||||||
@rank_zero_only
|
|
||||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||||
logger.info("***** Test results *****")
|
rank_zero_info("***** Test results *****")
|
||||||
metrics = trainer.callback_metrics
|
metrics = trainer.callback_metrics
|
||||||
# Log and save results to file
|
# Log and save results to file
|
||||||
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
|
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
|
||||||
with open(output_test_results_file, "w") as writer:
|
with open(output_test_results_file, "w") as writer:
|
||||||
for key in sorted(metrics):
|
for key in sorted(metrics):
|
||||||
if key not in ["log", "progress_bar"]:
|
if key not in ["log", "progress_bar"]:
|
||||||
logger.info("{} = {}\n".format(key, str(metrics[key])))
|
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
||||||
writer.write("{} = {}\n".format(key, str(metrics[key])))
|
writer.write("{} = {}\n".format(key, str(metrics[key])))
|
||||||
|
|
||||||
|
|
||||||
@@ -251,26 +236,23 @@ def add_generic_args(parser, root_dir) -> None:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--fp16_opt_level",
|
"--fp16_opt_level",
|
||||||
type=str,
|
type=str,
|
||||||
default="O1",
|
default="O2",
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
)
|
)
|
||||||
parser.add_argument("--fast_dev_run", action="store_true")
|
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int, default=0)
|
||||||
parser.add_argument("--gpus", type=int, default=1)
|
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
|
||||||
parser.add_argument("--n_tpu_cores", type=int, default=0)
|
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
|
||||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
|
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--gradient_accumulation_steps",
|
"--gradient_accumulation_steps",
|
||||||
|
dest="accumulate_grad_batches",
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
|
|
||||||
parser.add_argument("--val_check_interval", default=1.0, type=float)
|
|
||||||
|
|
||||||
|
|
||||||
def generic_train(
|
def generic_train(
|
||||||
@@ -283,10 +265,13 @@ def generic_train(
|
|||||||
logging_callback=None,
|
logging_callback=None,
|
||||||
**extra_train_kwargs
|
**extra_train_kwargs
|
||||||
):
|
):
|
||||||
|
pl.seed_everything(args.seed)
|
||||||
|
|
||||||
# init model
|
# init model
|
||||||
set_seed(args)
|
|
||||||
odir = Path(model.hparams.output_dir)
|
odir = Path(model.hparams.output_dir)
|
||||||
odir.mkdir(exist_ok=True)
|
odir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# add custom checkpoints
|
||||||
if checkpoint_callback is None:
|
if checkpoint_callback is None:
|
||||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||||
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
|
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
|
||||||
@@ -296,38 +281,25 @@ def generic_train(
|
|||||||
|
|
||||||
train_params = {}
|
train_params = {}
|
||||||
|
|
||||||
|
# TODO: remove with PyTorch 1.6 since pl uses native amp
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
train_params["use_amp"] = args.fp16
|
train_params["precision"] = 16
|
||||||
train_params["amp_level"] = args.fp16_opt_level
|
train_params["amp_level"] = args.fp16_opt_level
|
||||||
|
|
||||||
if args.n_tpu_cores > 0:
|
|
||||||
global xm
|
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
|
|
||||||
train_params["num_tpu_cores"] = args.n_tpu_cores
|
|
||||||
train_params["gpus"] = 0
|
|
||||||
|
|
||||||
if args.gpus > 1:
|
if args.gpus > 1:
|
||||||
train_params["distributed_backend"] = "ddp"
|
train_params["distributed_backend"] = "ddp"
|
||||||
|
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
logger=logger,
|
args,
|
||||||
accumulate_grad_batches=args.gradient_accumulation_steps,
|
|
||||||
gpus=args.gpus,
|
|
||||||
max_epochs=args.num_train_epochs,
|
|
||||||
early_stop_callback=early_stopping_callback,
|
|
||||||
gradient_clip_val=args.max_grad_norm,
|
|
||||||
checkpoint_callback=checkpoint_callback,
|
|
||||||
callbacks=[logging_callback] + extra_callbacks,
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
val_check_interval=args.val_check_interval,
|
|
||||||
weights_summary=None,
|
weights_summary=None,
|
||||||
resume_from_checkpoint=args.resume_from_checkpoint,
|
callbacks=[logging_callback] + extra_callbacks,
|
||||||
|
logger=logger,
|
||||||
|
checkpoint_callback=checkpoint_callback,
|
||||||
|
early_stop_callback=early_stopping_callback,
|
||||||
**train_params,
|
**train_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
trainer.logger.log_hyperparams(args)
|
|
||||||
trainer.logger.save()
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ psutil
|
|||||||
sacrebleu
|
sacrebleu
|
||||||
rouge-score
|
rouge-score
|
||||||
tensorflow_datasets
|
tensorflow_datasets
|
||||||
pytorch-lightning==0.8.1
|
pytorch-lightning==0.8.5
|
||||||
matplotlib
|
matplotlib
|
||||||
git-python==1.0.3
|
git-python==1.0.3
|
||||||
faiss
|
faiss
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ Summarization Tips:
|
|||||||
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
|
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
|
||||||
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`.
|
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`.
|
||||||
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
|
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
|
||||||
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
|
- `wandb` can be used by specifying `--logger_name wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
|
||||||
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
|
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
|
||||||
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).
|
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).
|
||||||
|
|
||||||
@@ -124,7 +124,7 @@ model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
|
|||||||
```
|
```
|
||||||
|
|
||||||
#### XSUM Shared Task
|
#### XSUM Shared Task
|
||||||
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
|
Compare XSUM results with others by using `--logger_name wandb_shared`. This requires `wandb` registration.
|
||||||
|
|
||||||
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
|
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
|
||||||
```bash
|
```bash
|
||||||
@@ -135,7 +135,7 @@ WANDB_PROJECT='hf_xsum' ./finetune.sh \
|
|||||||
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
|
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
|
||||||
--num_train_epochs 6 \
|
--num_train_epochs 6 \
|
||||||
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
||||||
--logger wandb
|
--logger_name wandb
|
||||||
```
|
```
|
||||||
|
|
||||||
You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
|
You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
|
||||||
|
|||||||
@@ -221,8 +221,8 @@ class SummarizationModule(BaseTransformer):
|
|||||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||||
t_total = (
|
t_total = (
|
||||||
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
|
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
|
||||||
// self.hparams.gradient_accumulation_steps
|
// self.hparams.accumulate_grad_batches
|
||||||
* float(self.hparams.num_train_epochs)
|
* float(self.hparams.max_epochs)
|
||||||
)
|
)
|
||||||
scheduler = get_linear_schedule_with_warmup(
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
|
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
|
||||||
@@ -279,7 +279,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
parser.add_argument("--freeze_encoder", action="store_true")
|
parser.add_argument("--freeze_encoder", action="store_true")
|
||||||
parser.add_argument("--freeze_embeds", action="store_true")
|
parser.add_argument("--freeze_embeds", action="store_true")
|
||||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
||||||
parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||||
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
|
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
|
||||||
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||||
@@ -288,7 +288,6 @@ class SummarizationModule(BaseTransformer):
|
|||||||
)
|
)
|
||||||
parser.add_argument("--src_lang", type=str, default="", required=False)
|
parser.add_argument("--src_lang", type=str, default="", required=False)
|
||||||
parser.add_argument("--tgt_lang", type=str, default="", required=False)
|
parser.add_argument("--tgt_lang", type=str, default="", required=False)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@@ -318,22 +317,24 @@ def main(args, model=None) -> SummarizationModule:
|
|||||||
model: SummarizationModule = SummarizationModule(args)
|
model: SummarizationModule = SummarizationModule(args)
|
||||||
else:
|
else:
|
||||||
model: SummarizationModule = TranslationModule(args)
|
model: SummarizationModule = TranslationModule(args)
|
||||||
|
|
||||||
|
dataset = Path(args.data_dir).name
|
||||||
if (
|
if (
|
||||||
args.logger == "default"
|
args.logger_name == "default"
|
||||||
or args.fast_dev_run
|
or args.fast_dev_run
|
||||||
or str(args.output_dir).startswith("/tmp")
|
or str(args.output_dir).startswith("/tmp")
|
||||||
or str(args.output_dir).startswith("/var")
|
or str(args.output_dir).startswith("/var")
|
||||||
):
|
):
|
||||||
logger = True # don't pollute wandb logs unnecessarily
|
logger = True # don't pollute wandb logs unnecessarily
|
||||||
elif args.logger == "wandb":
|
elif args.logger_name == "wandb":
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
logger = WandbLogger(name=model.output_dir.name)
|
logger = WandbLogger(name=model.output_dir.name, project=dataset)
|
||||||
|
|
||||||
elif args.logger == "wandb_shared":
|
elif args.logger_name == "wandb_shared":
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
logger = WandbLogger(name=model.output_dir.name)
|
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||||
trainer: pl.Trainer = generic_train(
|
trainer: pl.Trainer = generic_train(
|
||||||
model,
|
model,
|
||||||
args,
|
args,
|
||||||
@@ -352,13 +353,17 @@ def main(args, model=None) -> SummarizationModule:
|
|||||||
model.hparams.test_checkpoint = checkpoints[-1]
|
model.hparams.test_checkpoint = checkpoints[-1]
|
||||||
trainer.resume_from_checkpoint = checkpoints[-1]
|
trainer.resume_from_checkpoint = checkpoints[-1]
|
||||||
trainer.logger.log_hyperparams(model.hparams)
|
trainer.logger.log_hyperparams(model.hparams)
|
||||||
trainer.test(model) # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
|
|
||||||
|
# test() without a model tests using the best checkpoint automatically
|
||||||
|
trainer.test()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@@ -10,5 +10,4 @@ python finetune.py \
|
|||||||
--do_predict \
|
--do_predict \
|
||||||
--n_val 1000 \
|
--n_val 1000 \
|
||||||
--val_check_interval 0.1 \
|
--val_check_interval 0.1 \
|
||||||
--sortish_sampler \
|
|
||||||
$@
|
$@
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ logging.basicConfig(level=logging.DEBUG)
|
|||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||||
CHEAP_ARGS = {
|
CHEAP_ARGS = {
|
||||||
"logger": "default",
|
"logger_name": "default",
|
||||||
"length_penalty": 0.5,
|
"length_penalty": 0.5,
|
||||||
"cache_dir": "",
|
"cache_dir": "",
|
||||||
"task": "summarization",
|
"task": "summarization",
|
||||||
@@ -48,7 +48,7 @@ CHEAP_ARGS = {
|
|||||||
"max_grad_norm": 1.0,
|
"max_grad_norm": 1.0,
|
||||||
"do_train": True,
|
"do_train": True,
|
||||||
"do_predict": True,
|
"do_predict": True,
|
||||||
"gradient_accumulation_steps": 1,
|
"accumulate_grad_batches": 1,
|
||||||
"server_ip": "",
|
"server_ip": "",
|
||||||
"server_port": "",
|
"server_port": "",
|
||||||
"seed": 42,
|
"seed": 42,
|
||||||
@@ -60,7 +60,7 @@ CHEAP_ARGS = {
|
|||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
"adam_epsilon": 1e-08,
|
"adam_epsilon": 1e-08,
|
||||||
"warmup_steps": 0,
|
"warmup_steps": 0,
|
||||||
"num_train_epochs": 1,
|
"max_epochs": 1,
|
||||||
"train_batch_size": 2,
|
"train_batch_size": 2,
|
||||||
"eval_batch_size": 2,
|
"eval_batch_size": 2,
|
||||||
"max_source_length": 12,
|
"max_source_length": 12,
|
||||||
@@ -122,7 +122,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
updates = dict(
|
updates = dict(
|
||||||
student_encoder_layers=2,
|
student_encoder_layers=2,
|
||||||
student_decoder_layers=1,
|
student_decoder_layers=1,
|
||||||
num_train_epochs=4,
|
max_epochs=4,
|
||||||
val_check_interval=0.25,
|
val_check_interval=0.25,
|
||||||
alpha_hid=2.0,
|
alpha_hid=2.0,
|
||||||
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||||
@@ -156,7 +156,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
default_updates = dict(
|
default_updates = dict(
|
||||||
train_batch_size=1,
|
train_batch_size=1,
|
||||||
eval_batch_size=2,
|
eval_batch_size=2,
|
||||||
num_train_epochs=2,
|
max_epochs=2,
|
||||||
alpha_mlm=0.2,
|
alpha_mlm=0.2,
|
||||||
alpha_ce=0.8,
|
alpha_ce=0.8,
|
||||||
do_predict=True,
|
do_predict=True,
|
||||||
@@ -187,7 +187,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
|
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
|
||||||
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
|
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
|
||||||
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
||||||
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
|
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) + 1)
|
||||||
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
||||||
self.assertEqual(len(metrics["test"]), 1)
|
self.assertEqual(len(metrics["test"]), 1)
|
||||||
return model
|
return model
|
||||||
|
|||||||
@@ -17,5 +17,5 @@ python finetune.py \
|
|||||||
--model_name_or_path facebook/mbart-large-cc25 \
|
--model_name_or_path facebook/mbart-large-cc25 \
|
||||||
--task translation \
|
--task translation \
|
||||||
--warmup_steps 500 \
|
--warmup_steps 500 \
|
||||||
--logger wandb --sortish_sampler \
|
--logger_name wandb --sortish_sampler \
|
||||||
$@
|
$@
|
||||||
|
|||||||
Reference in New Issue
Block a user