examples/seq2seq/run_eval.py fixes and docs (#5322)

This commit is contained in:
Sam Shleifer
2020-06-26 19:20:43 -04:00
committed by GitHub
parent 5543b30aa6
commit 393b8dc09a
5 changed files with 79 additions and 27 deletions

View File

@@ -37,13 +37,50 @@ export ENRO_DIR=${PWD}/wmt_en_ro
If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
The `.source` files are the input, the `.target` files are the desired output.
### Evaluation
### Evaluation Commands
To create summaries for each article in dataset, run:
To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models.
If 'translation' is in your task name, the computed metric will be BLEU. Otherwise, ROUGE will be used.
For t5, you need to specify --task translation_{src}_to_{tgt} as follows:
```bash
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
export DATA_DIR=wmt_en_ro
python run_eval.py t5_base \
$DATA_DIR/val.source mbart_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path enro_bleu.json \
--task translation_en_to_ro \
--n_obs 100 \
--device cuda \
--fp16 \
--bs 32
```
This command works for MBART, although the BLEU score is suspiciously low.
```bash
export DATA_DIR=wmt_en_ro
python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path enro_bleu.json \
--task translation \
--n_obs 100 \
--device cuda \
--fp16 \
--bs 32
```
Summarization (xsum will be very similar):
```bash
export DATA_DIR=cnn_dm
python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
--reference_path $DATA_DIR/val.target \
--score_path cnn_rouge.json \
--task summarization \
--n_obs 100 \
--device cuda \
--fp16 \
--bs 32
```
The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
### Summarization Finetuning

View File

@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
try:
from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score, trim_batch
except ImportError:
from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score, trim_batch
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -29,6 +29,7 @@ def generate_summaries_or_translations(
batch_size: int = 8,
device: str = DEFAULT_DEVICE,
fp16=False,
task="summarization",
**gen_kwargs,
) -> None:
fout = Path(out_file).open("w", encoding="utf-8")
@@ -40,7 +41,7 @@ def generate_summaries_or_translations(
tokenizer = AutoTokenizer.from_pretrained(model_name)
# update config with summarization specific params
use_task_specific_params(model, "summarization")
use_task_specific_params(model, task)
for batch in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
@@ -48,7 +49,8 @@ def generate_summaries_or_translations(
batch = tokenizer(batch, max_length=1024, return_tensors="pt", truncation=True, padding="max_length").to(
device
)
summaries = model.generate(**batch, **gen_kwargs)
input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for hypothesis in dec:
fout.write(hypothesis + "\n")
@@ -57,30 +59,42 @@ def generate_summaries_or_translations(
def run_generate():
parser = argparse.ArgumentParser()
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
parser.add_argument("output_path", type=str, help="where to save summaries")
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
parser.add_argument("save_path", type=str, help="where to save summaries")
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
parser.add_argument("--metric", type=str, choices=["bleu", "rouge"], default="rouge")
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
parser.add_argument(
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
)
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
if args.n_obs > 0:
examples = examples[: args.n_obs]
generate_summaries_or_translations(
examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16
examples,
args.save_path,
args.model_name,
batch_size=args.bs,
device=args.device,
fp16=args.fp16,
task=args.task,
)
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
scores = {}
if args.reference_path is not None:
score_fn = {"bleu": calculate_bleu_score, "rouge": calculate_rouge}[args.metric]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
scores: dict = score_fn(output_lns, reference_lns)
if args.score_path is not None:
json.dump(scores, open("score_path", "w+"))
if args.reference_path is None:
return
# Compute scores
score_fn = calculate_bleu_score if "translation" in args.task else calculate_rouge
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
scores: dict = score_fn(output_lns, reference_lns)
if args.score_path is not None:
json.dump(scores, open(args.score_path, "w+"))
return scores

View File

@@ -198,7 +198,7 @@ def test_run_eval_bart(model):
assert not output_file_name.exists()
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
_dump_articles(input_file_name, articles)
testargs = ["run_eval.py", str(input_file_name), str(output_file_name), model] # TODO: test score_path
testargs = ["run_eval.py", model, str(input_file_name), str(output_file_name)] # TODO: test score_path
with patch.object(sys, "argv", testargs):
run_generate()
assert Path(output_file_name).exists()

View File

@@ -60,8 +60,9 @@ def lmap(f: Callable, x: Iterable) -> List:
return list(map(f, x))
def calculate_bleu_score(output_lns, refs_lns) -> dict:
return {"bleu": corpus_bleu(output_lns, [refs_lns]).score}
def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict:
"""Uses sacrebleu's corpus_bleu implementation."""
return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score}
def trim_batch(