examples/seq2seq: never override $WANDB_PROJECT (#5407)

Sam Shleifer
2020-06-30 15:29:13 -04:00
committed by GitHub
parent 32d2031458
commit 27a7fe7a8d
2 changed files with 6 additions and 8 deletions


@@ -71,7 +71,7 @@ Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_t
 - If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
 - For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
 - `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
-- `wandb` can be used by specifying `--logger wandb_shared` or `--logger wandb`. It is useful for reproducibility.
+- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
 - This warning can be safely ignored:
 > "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
 - Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
@@ -111,14 +111,14 @@ Compare XSUM results with others by using `--logger wandb_shared`. This requires
 Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
 ```bash
-./finetune.sh \
+WANDB_PROJECT='hf_xsum' ./finetune.sh \
 --data_dir $XSUM_DIR \
 --output_dir xsum_frozen_embs \
 --model_name_or_path facebook/bart-large \
---logger wandb_shared \
 --train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
 --num_train_epochs 6 \
---max_target_length=60 --val_max_target_length=60 --test_max_target_length=100
+--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
+--logger wandb
 ```
 You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
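
A note on why the project name moves into the environment: before this commit the logger received an explicit `project=` argument derived from the data directory name, so a user-supplied `$WANDB_PROJECT` was never used. The sketch below only models that precedence. It does not import `wandb`, the helper names `project_before` and `effective_project` are hypothetical, and the rule "an explicit argument wins over the environment variable" is an assumption about wandb's behaviour rather than something this diff shows.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def project_before(data_dir: str, logger: str) -> Optional[str]:
    """Project name the pre-commit code passed explicitly (hypothetical helper)."""
    dataset = Path(data_dir).name
    if logger == "wandb":
        return dataset
    if logger == "wandb_shared":
        return f"hf_{dataset}"
    return None


def effective_project(
    explicit: Optional[str], env: Mapping[str, str] = os.environ
) -> Optional[str]:
    """Assumed precedence: an explicit project beats WANDB_PROJECT."""
    if explicit is not None:
        return explicit
    return env.get("WANDB_PROJECT")


if __name__ == "__main__":
    env = {"WANDB_PROJECT": "my_project"}
    # Before: the environment variable is shadowed by the derived name.
    print(effective_project(project_before("/data/xsum", "wandb"), env))
    # After: no explicit project is passed, so the variable decides.
    print(effective_project(None, env))
```

Under these assumptions, the old `--logger wandb_shared` behaviour is reproduced by exporting `WANDB_PROJECT='hf_xsum'` and using `--logger wandb`, which is what the updated README command does.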


@@ -298,8 +298,6 @@ def main(args, model=None) -> SummarizationModule:
             model: SummarizationModule = SummarizationModule(args)
         else:
             model: SummarizationModule = TranslationModule(args)
-    dataset = Path(args.data_dir).name
     if (
         args.logger == "default"
         or args.fast_dev_run
@@ -310,12 +308,12 @@ def main(args, model=None) -> SummarizationModule:
     elif args.logger == "wandb":
         from pytorch_lightning.loggers import WandbLogger
-        logger = WandbLogger(name=model.output_dir.name, project=dataset)
+        logger = WandbLogger(name=model.output_dir.name)
     elif args.logger == "wandb_shared":
         from pytorch_lightning.loggers import WandbLogger
-        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
+        logger = WandbLogger(name=model.output_dir.name)
     trainer: pl.Trainer = generic_train(
         model,
         args,