examples/seq2seq: never override $WANDB_PROJECT (#5407)
This commit is contained in:
@@ -71,7 +71,7 @@ Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_t
|
||||
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
|
||||
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
|
||||
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
|
||||
- `wandb` can be used by specifying `--logger wandb_shared` or `--logger wandb`. It is useful for reproducibility.
|
||||
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
|
||||
- This warning can be safely ignored:
|
||||
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
|
||||
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
|
||||
@@ -111,14 +111,14 @@ Compare XSUM results with others by using `--logger wandb_shared`. This requires
|
||||
|
||||
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
|
||||
```bash
|
||||
./finetune.sh \
|
||||
WANDB_PROJECT='hf_xsum' ./finetune.sh \
|
||||
--data_dir $XSUM_DIR \
|
||||
--output_dir xsum_frozen_embs \
|
||||
--model_name_or_path facebook/bart-large \
|
||||
--logger wandb_shared \
|
||||
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
|
||||
--num_train_epochs 6 \
|
||||
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100
|
||||
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
||||
--logger wandb
|
||||
```
|
||||
|
||||
You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
|
||||
|
||||
@@ -298,8 +298,6 @@ def main(args, model=None) -> SummarizationModule:
|
||||
model: SummarizationModule = SummarizationModule(args)
|
||||
else:
|
||||
model: SummarizationModule = TranslationModule(args)
|
||||
|
||||
dataset = Path(args.data_dir).name
|
||||
if (
|
||||
args.logger == "default"
|
||||
or args.fast_dev_run
|
||||
@@ -310,12 +308,12 @@ def main(args, model=None) -> SummarizationModule:
|
||||
elif args.logger == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name, project=dataset)
|
||||
logger = WandbLogger(name=model.output_dir.name)
|
||||
|
||||
elif args.logger == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
logger = WandbLogger(name=model.output_dir.name)
|
||||
trainer: pl.Trainer = generic_train(
|
||||
model,
|
||||
args,
|
||||
|
||||
Reference in New Issue
Block a user