Lightning Updates for v0.8.5 (#5798)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -60,7 +60,7 @@ Summarization Tips:
|
||||
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
|
||||
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
|
||||
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
|
||||
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
|
||||
- `wandb` can be used by specifying `--logger_name wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
|
||||
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
|
||||
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
|
||||
|
||||
@@ -124,7 +124,7 @@ model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
|
||||
```
|
||||
|
||||
#### XSUM Shared Task
|
||||
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
|
||||
Compare XSUM results with others by using `--logger_name wandb_shared`. This requires `wandb` registration.
|
||||
|
||||
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
|
||||
```bash
|
||||
@@ -135,7 +135,7 @@ WANDB_PROJECT='hf_xsum' ./finetune.sh \
|
||||
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
|
||||
--num_train_epochs 6 \
|
||||
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
||||
--logger wandb
|
||||
--logger_name wandb
|
||||
```
|
||||
|
||||
You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
|
||||
|
||||
@@ -221,8 +221,8 @@ class SummarizationModule(BaseTransformer):
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||
t_total = (
|
||||
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
|
||||
// self.hparams.gradient_accumulation_steps
|
||||
* float(self.hparams.num_train_epochs)
|
||||
// self.hparams.accumulate_grad_batches
|
||||
* float(self.hparams.max_epochs)
|
||||
)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
|
||||
@@ -279,7 +279,7 @@ class SummarizationModule(BaseTransformer):
|
||||
parser.add_argument("--freeze_encoder", action="store_true")
|
||||
parser.add_argument("--freeze_embeds", action="store_true")
|
||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
||||
parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
@@ -288,7 +288,6 @@ class SummarizationModule(BaseTransformer):
|
||||
)
|
||||
parser.add_argument("--src_lang", type=str, default="", required=False)
|
||||
parser.add_argument("--tgt_lang", type=str, default="", required=False)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -318,22 +317,24 @@ def main(args, model=None) -> SummarizationModule:
|
||||
model: SummarizationModule = SummarizationModule(args)
|
||||
else:
|
||||
model: SummarizationModule = TranslationModule(args)
|
||||
|
||||
dataset = Path(args.data_dir).name
|
||||
if (
|
||||
args.logger == "default"
|
||||
args.logger_name == "default"
|
||||
or args.fast_dev_run
|
||||
or str(args.output_dir).startswith("/tmp")
|
||||
or str(args.output_dir).startswith("/var")
|
||||
):
|
||||
logger = True # don't pollute wandb logs unnecessarily
|
||||
elif args.logger == "wandb":
|
||||
elif args.logger_name == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name)
|
||||
logger = WandbLogger(name=model.output_dir.name, project=dataset)
|
||||
|
||||
elif args.logger == "wandb_shared":
|
||||
elif args.logger_name == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name)
|
||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
trainer: pl.Trainer = generic_train(
|
||||
model,
|
||||
args,
|
||||
@@ -352,13 +353,17 @@ def main(args, model=None) -> SummarizationModule:
|
||||
model.hparams.test_checkpoint = checkpoints[-1]
|
||||
trainer.resume_from_checkpoint = checkpoints[-1]
|
||||
trainer.logger.log_hyperparams(model.hparams)
|
||||
trainer.test(model) # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
|
||||
|
||||
# test() without a model tests using the best checkpoint automatically
|
||||
trainer.test()
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
||||
@@ -10,5 +10,4 @@ python finetune.py \
|
||||
--do_predict \
|
||||
--n_val 1000 \
|
||||
--val_check_interval 0.1 \
|
||||
--sortish_sampler \
|
||||
$@
|
||||
|
||||
@@ -26,7 +26,7 @@ logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger()
|
||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
CHEAP_ARGS = {
|
||||
"logger": "default",
|
||||
"logger_name": "default",
|
||||
"length_penalty": 0.5,
|
||||
"cache_dir": "",
|
||||
"task": "summarization",
|
||||
@@ -48,7 +48,7 @@ CHEAP_ARGS = {
|
||||
"max_grad_norm": 1.0,
|
||||
"do_train": True,
|
||||
"do_predict": True,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"accumulate_grad_batches": 1,
|
||||
"server_ip": "",
|
||||
"server_port": "",
|
||||
"seed": 42,
|
||||
@@ -60,7 +60,7 @@ CHEAP_ARGS = {
|
||||
"weight_decay": 0.0,
|
||||
"adam_epsilon": 1e-08,
|
||||
"warmup_steps": 0,
|
||||
"num_train_epochs": 1,
|
||||
"max_epochs": 1,
|
||||
"train_batch_size": 2,
|
||||
"eval_batch_size": 2,
|
||||
"max_source_length": 12,
|
||||
@@ -122,7 +122,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
num_train_epochs=4,
|
||||
max_epochs=4,
|
||||
val_check_interval=0.25,
|
||||
alpha_hid=2.0,
|
||||
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||
@@ -156,7 +156,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
default_updates = dict(
|
||||
train_batch_size=1,
|
||||
eval_batch_size=2,
|
||||
num_train_epochs=2,
|
||||
max_epochs=2,
|
||||
alpha_mlm=0.2,
|
||||
alpha_ce=0.8,
|
||||
do_predict=True,
|
||||
@@ -187,7 +187,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
|
||||
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
|
||||
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
||||
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
|
||||
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) + 1)
|
||||
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
||||
self.assertEqual(len(metrics["test"]), 1)
|
||||
return model
|
||||
|
||||
@@ -17,5 +17,5 @@ python finetune.py \
|
||||
--model_name_or_path facebook/mbart-large-cc25 \
|
||||
--task translation \
|
||||
--warmup_steps 500 \
|
||||
--logger wandb --sortish_sampler \
|
||||
--logger_name wandb --sortish_sampler \
|
||||
$@
|
||||
|
||||
Reference in New Issue
Block a user