From a316a6aaa8fcfe9ff0004b122078313f0eae0631 Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Mon, 29 Jun 2020 10:36:04 -0400
Subject: [PATCH] [seq2seq docs] Move evaluation down, fix typo (#5365)

---
 examples/seq2seq/README.md | 94 +++++++++++++++++++-------------------
 1 file changed, 48 insertions(+), 46 deletions(-)

diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md
index 3f74c9bdb2..8cfaf82b5e 100644
--- a/examples/seq2seq/README.md
+++ b/examples/seq2seq/README.md
@@ -3,6 +3,7 @@
 Summarization support is more mature than translation support.
 Please tag @sshleifer with any issues/unexpected behaviors, or send a PR!
 For `bertabs` instructions, see `bertabs/README.md`.
+
 
 ### Data
 CNN/DailyMail data
@@ -37,50 +38,6 @@ export ENRO_DIR=${PWD}/wmt_en_ro
 If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
 The `.source` files are the input, the `.target` files are the desired output.
 
-### Evaluation Commands
-
-To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models.
-If 'translation' is in your task name, the computed metric will be BLEU. Otherwise, ROUGE will be used.
-
-For t5, you need to specify --task translation_{src}_to_{tgt} as follows:
-```bash
-export DATA_DIR=wmt_en_ro
-python run_eval.py t5_base \
-    $DATA_DIR/val.source mbart_val_generations.txt \
-    --reference_path $DATA_DIR/val.target \
-    --score_path enro_bleu.json \
-    --task translation_en_to_ro \
-    --n_obs 100 \
-    --device cuda \
-    --fp16 \
-    --bs 32
-```
-
-This command works for MBART, although the BLEU score is suspiciously low.
-```bash
-export DATA_DIR=wmt_en_ro
-python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
-    --reference_path $DATA_DIR/val.target \
-    --score_path enro_bleu.json \
-    --task translation \
-    --n_obs 100 \
-    --device cuda \
-    --fp16 \
-    --bs 32
-```
-
-Summarization (xsum will be very similar):
-```bash
-export DATA_DIR=cnn_dm
-python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
-    --reference_path $DATA_DIR/val.target \
-    --score_path cnn_rouge.json \
-    --task summarization \
-    --n_obs 100 \
-    --device cuda \
-    --fp16 \
-    --bs 32
-```
 
 ### Summarization Finetuning
 
@@ -147,8 +104,7 @@ from transformers import AutoModelForSeq2SeqLM
 model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
 ```
 
-
-### XSUM Shared Task
+#### XSUM Shared Task
 
 Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
 Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
@@ -165,6 +121,52 @@ Here is an example command, but you can do whatever you want. Hopefully this wil
 
 You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
 
+### Evaluation Commands
+
+To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models.
+If 'translation' is in your task name, the computed metric will be BLEU. Otherwise, ROUGE will be used.
+
+For t5, you need to specify --task translation_{src}_to_{tgt} as follows:
+```bash
+export DATA_DIR=wmt_en_ro
+python run_eval.py t5_base \
+    $DATA_DIR/val.source t5_val_generations.txt \
+    --reference_path $DATA_DIR/val.target \
+    --score_path enro_bleu.json \
+    --task translation_en_to_ro \
+    --n_obs 100 \
+    --device cuda \
+    --fp16 \
+    --bs 32
+```
+
+This command works for MBART, although the BLEU score is suspiciously low.
+```bash
+export DATA_DIR=wmt_en_ro
+python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
+    --reference_path $DATA_DIR/val.target \
+    --score_path enro_bleu.json \
+    --task translation \
+    --n_obs 100 \
+    --device cuda \
+    --fp16 \
+    --bs 32
+```
+
+Summarization (xsum will be very similar):
+```bash
+export DATA_DIR=cnn_dm
+python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
+    --reference_path $DATA_DIR/val.target \
+    --score_path cnn_rouge.json \
+    --task summarization \
+    --n_obs 100 \
+    --device cuda \
+    --fp16 \
+    --bs 32
+```
+
+
 ### DistilBART
 
 For the CNN/DailyMail dataset, (relatively longer, more extractive summaries), we found a simple technique that works: