examples/seq2seq: never override $WANDB_PROJECT (#5407)
This commit is contained in:
@@ -71,7 +71,7 @@ Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_t
|
|||||||
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
|
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
|
||||||
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`.
|
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`.
|
||||||
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
|
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
|
||||||
- `wandb` can be used by specifying `--logger wandb_shared` or `--logger wandb`. It is useful for reproducibility.
|
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. To take part in the XSUM shared task, set the environment variable `WANDB_PROJECT='hf_xsum'`.
|
||||||
- This warning can be safely ignored:
|
- This warning can be safely ignored:
|
||||||
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
|
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
|
||||||
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
|
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
|
||||||
@@ -111,14 +111,14 @@ Compare XSUM results with others by using `--logger wandb_shared`. This requires
|
|||||||
|
|
||||||
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
|
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
|
||||||
```bash
|
```bash
|
||||||
./finetune.sh \
|
WANDB_PROJECT='hf_xsum' ./finetune.sh \
|
||||||
--data_dir $XSUM_DIR \
|
--data_dir $XSUM_DIR \
|
||||||
--output_dir xsum_frozen_embs \
|
--output_dir xsum_frozen_embs \
|
||||||
--model_name_or_path facebook/bart-large \
|
--model_name_or_path facebook/bart-large \
|
||||||
--logger wandb_shared \
|
|
||||||
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
|
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
|
||||||
--num_train_epochs 6 \
|
--num_train_epochs 6 \
|
||||||
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100
|
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
||||||
|
--logger wandb
|
||||||
```
|
```
|
||||||
|
|
||||||
You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
|
You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
|
||||||
|
|||||||
@@ -298,8 +298,6 @@ def main(args, model=None) -> SummarizationModule:
|
|||||||
model: SummarizationModule = SummarizationModule(args)
|
model: SummarizationModule = SummarizationModule(args)
|
||||||
else:
|
else:
|
||||||
model: SummarizationModule = TranslationModule(args)
|
model: SummarizationModule = TranslationModule(args)
|
||||||
|
|
||||||
dataset = Path(args.data_dir).name
|
|
||||||
if (
|
if (
|
||||||
args.logger == "default"
|
args.logger == "default"
|
||||||
or args.fast_dev_run
|
or args.fast_dev_run
|
||||||
@@ -310,12 +308,12 @@ def main(args, model=None) -> SummarizationModule:
|
|||||||
elif args.logger == "wandb":
|
elif args.logger == "wandb":
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
logger = WandbLogger(name=model.output_dir.name, project=dataset)
|
logger = WandbLogger(name=model.output_dir.name)
|
||||||
|
|
||||||
elif args.logger == "wandb_shared":
|
elif args.logger == "wandb_shared":
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
logger = WandbLogger(name=model.output_dir.name)
|
||||||
trainer: pl.Trainer = generic_train(
|
trainer: pl.Trainer = generic_train(
|
||||||
model,
|
model,
|
||||||
args,
|
args,
|
||||||
|
|||||||
Reference in New Issue
Block a user