diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index 29866c09a9..762dd5b4b9 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -71,7 +71,7 @@ Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_t - If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter. - For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()` - `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM. -- `wandb` can be used by specifying `--logger wandb_shared` or `--logger wandb`. It is useful for reproducibility. +- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task. - This warning can be safely ignored: > "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']" - Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start). @@ -111,14 +111,14 @@ Compare XSUM results with others by using `--logger wandb_shared`. This requires Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier! ```bash -./finetune.sh \ +WANDB_PROJECT='hf_xsum' ./finetune.sh \ --data_dir $XSUM_DIR \ --output_dir xsum_frozen_embs \ --model_name_or_path facebook/bart-large \ - --logger wandb_shared \ --train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \ --num_train_epochs 6 \ - --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 + --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \ + --logger wandb ``` You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-) diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index 598f05f410..fd1d00ecf3 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -298,8 +298,6 @@ def main(args, model=None) -> SummarizationModule: model: SummarizationModule = SummarizationModule(args) else: model: SummarizationModule = TranslationModule(args) - - dataset = Path(args.data_dir).name if ( args.logger == "default" or args.fast_dev_run @@ -310,12 +308,12 @@ def main(args, model=None) -> SummarizationModule: elif args.logger == "wandb": from pytorch_lightning.loggers import WandbLogger - logger = WandbLogger(name=model.output_dir.name, project=dataset) + logger = WandbLogger(name=model.output_dir.name) elif args.logger == "wandb_shared": from pytorch_lightning.loggers import WandbLogger - logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}") + logger = WandbLogger(name=model.output_dir.name) trainer: pl.Trainer = generic_train( model, args,