Lightning Updates for v0.8.5 (#5798)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Nathan Raw
2020-07-17 20:43:06 -06:00
committed by GitHub
parent 615be03f9d
commit 529850ae7b
7 changed files with 73 additions and 97 deletions

View File

@@ -60,7 +60,7 @@ Summarization Tips:
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
- `wandb` can be used by specifying `--logger_name wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
@@ -124,7 +124,7 @@ model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
```
#### XSUM Shared Task
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
Compare XSUM results with others by using `--logger_name wandb_shared`. This requires `wandb` registration.
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
```bash
@@ -135,7 +135,7 @@ WANDB_PROJECT='hf_xsum' ./finetune.sh \
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
--num_train_epochs 6 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--logger wandb
--logger_name wandb
```
You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)

View File

@@ -221,8 +221,8 @@ class SummarizationModule(BaseTransformer):
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
t_total = (
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.gradient_accumulation_steps
* float(self.hparams.num_train_epochs)
// self.hparams.accumulate_grad_batches
* float(self.hparams.max_epochs)
)
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
@@ -279,7 +279,7 @@ class SummarizationModule(BaseTransformer):
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
parser.add_argument("--sortish_sampler", action="store_true", default=False)
parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
@@ -288,7 +288,6 @@ class SummarizationModule(BaseTransformer):
)
parser.add_argument("--src_lang", type=str, default="", required=False)
parser.add_argument("--tgt_lang", type=str, default="", required=False)
return parser
@@ -318,22 +317,24 @@ def main(args, model=None) -> SummarizationModule:
model: SummarizationModule = SummarizationModule(args)
else:
model: SummarizationModule = TranslationModule(args)
dataset = Path(args.data_dir).name
if (
args.logger == "default"
args.logger_name == "default"
or args.fast_dev_run
or str(args.output_dir).startswith("/tmp")
or str(args.output_dir).startswith("/var")
):
logger = True # don't pollute wandb logs unnecessarily
elif args.logger == "wandb":
elif args.logger_name == "wandb":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name)
logger = WandbLogger(name=model.output_dir.name, project=dataset)
elif args.logger == "wandb_shared":
elif args.logger_name == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name)
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
trainer: pl.Trainer = generic_train(
model,
args,
@@ -352,13 +353,17 @@ def main(args, model=None) -> SummarizationModule:
model.hparams.test_checkpoint = checkpoints[-1]
trainer.resume_from_checkpoint = checkpoints[-1]
trainer.logger.log_hyperparams(model.hparams)
trainer.test(model) # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
# test() without a model tests using the best checkpoint automatically
trainer.test()
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
main(args)

View File

@@ -10,5 +10,4 @@ python finetune.py \
--do_predict \
--n_val 1000 \
--val_check_interval 0.1 \
--sortish_sampler \
$@

View File

@@ -26,7 +26,7 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"logger": "default",
"logger_name": "default",
"length_penalty": 0.5,
"cache_dir": "",
"task": "summarization",
@@ -48,7 +48,7 @@ CHEAP_ARGS = {
"max_grad_norm": 1.0,
"do_train": True,
"do_predict": True,
"gradient_accumulation_steps": 1,
"accumulate_grad_batches": 1,
"server_ip": "",
"server_port": "",
"seed": 42,
@@ -60,7 +60,7 @@ CHEAP_ARGS = {
"weight_decay": 0.0,
"adam_epsilon": 1e-08,
"warmup_steps": 0,
"num_train_epochs": 1,
"max_epochs": 1,
"train_batch_size": 2,
"eval_batch_size": 2,
"max_source_length": 12,
@@ -122,7 +122,7 @@ class TestSummarizationDistiller(unittest.TestCase):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
num_train_epochs=4,
max_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
@@ -156,7 +156,7 @@ class TestSummarizationDistiller(unittest.TestCase):
default_updates = dict(
train_batch_size=1,
eval_batch_size=2,
num_train_epochs=2,
max_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
@@ -187,7 +187,7 @@ class TestSummarizationDistiller(unittest.TestCase):
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) + 1)
self.assertEqual(len(metrics["val"]), desired_n_evals)
self.assertEqual(len(metrics["test"]), 1)
return model

View File

@@ -17,5 +17,5 @@ python finetune.py \
--model_name_or_path facebook/mbart-large-cc25 \
--task translation \
--warmup_steps 500 \
--logger wandb --sortish_sampler \
--logger_name wandb --sortish_sampler \
$@