From 393b8dc09a97197df1937a7e86c0c6b4ce69c7e9 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 26 Jun 2020 19:20:43 -0400 Subject: [PATCH] examples/seq2seq/run_eval.py fixes and docs (#5322) --- examples/seq2seq/README.md | 45 +++++++++++++++++++-- examples/seq2seq/run_eval.py | 48 +++++++++++++++-------- examples/seq2seq/test_seq2seq_examples.py | 2 +- examples/seq2seq/utils.py | 5 ++- tests/test_modeling_bart.py | 6 +-- 5 files changed, 79 insertions(+), 27 deletions(-) diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index 832e0943e5..3f74c9bdb2 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -37,13 +37,50 @@ export ENRO_DIR=${PWD}/wmt_en_ro If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target. The `.source` files are the input, the `.target` files are the desired output. -### Evaluation +### Evaluation Commands -To create summaries for each article in dataset, run: +To create summaries for each article in a dataset, we use `run_eval.py`. Here are a few commands that run eval for different tasks and models. +If 'translation' is in your task name, the computed metric will be BLEU. Otherwise, ROUGE will be used. + +For t5, you need to specify --task translation_{src}_to_{tgt} as follows: ```bash -python run_eval.py test_generations.txt --score_path rouge_scores.txt +export DATA_DIR=wmt_en_ro +python run_eval.py t5-base \ + $DATA_DIR/val.source t5_val_generations.txt \ + --reference_path $DATA_DIR/val.target \ + --score_path enro_bleu.json \ + --task translation_en_to_ro \ + --n_obs 100 \ + --device cuda \ + --fp16 \ + --bs 32 +``` + +This command works for MBART, although the BLEU score is suspiciously low. 
+```bash +export DATA_DIR=wmt_en_ro +python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \ + --reference_path $DATA_DIR/val.target \ + --score_path enro_bleu.json \ + --task translation \ + --n_obs 100 \ + --device cuda \ + --fp16 \ + --bs 32 +``` + +Summarization (xsum will be very similar): +```bash +export DATA_DIR=cnn_dm +python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \ + --reference_path $DATA_DIR/val.target \ + --score_path cnn_rouge.json \ + --task summarization \ + --n_obs 100 \ + --device cuda \ + --fp16 \ + --bs 32 ``` -The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system. ### Summarization Finetuning diff --git a/examples/seq2seq/run_eval.py b/examples/seq2seq/run_eval.py index 6a0480f36d..3f92d56a29 100644 --- a/examples/seq2seq/run_eval.py +++ b/examples/seq2seq/run_eval.py @@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer try: - from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score + from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score, trim_batch except ImportError: - from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score + from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score, trim_batch DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -29,6 +29,7 @@ def generate_summaries_or_translations( batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False, + task="summarization", **gen_kwargs, ) -> None: fout = Path(out_file).open("w", encoding="utf-8") @@ -40,7 +41,7 @@ def generate_summaries_or_translations( tokenizer = AutoTokenizer.from_pretrained(model_name) # update config with summarization specific params - use_task_specific_params(model, "summarization") + use_task_specific_params(model, task) for batch in tqdm(list(chunks(examples, batch_size))): if "t5" 
in model_name: @@ -48,7 +49,8 @@ def generate_summaries_or_translations( batch = tokenizer(batch, max_length=1024, return_tensors="pt", truncation=True, padding="max_length").to( device ) - summaries = model.generate(**batch, **gen_kwargs) + input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id) + summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs) dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False) for hypothesis in dec: fout.write(hypothesis + "\n") @@ -57,30 +59,42 @@ def generate_summaries_or_translations( def run_generate(): parser = argparse.ArgumentParser() - parser.add_argument("input_path", type=str, help="like cnn_dm/test.source") - parser.add_argument("output_path", type=str, help="where to save summaries") parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.") + parser.add_argument("input_path", type=str, help="like cnn_dm/test.source") + parser.add_argument("save_path", type=str, help="where to save summaries") + parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt") parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format") - parser.add_argument("--metric", type=str, choices=["bleu", "rouge"], default="rouge") parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.") + parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization") parser.add_argument("--bs", type=int, default=8, required=False, help="batch size") + parser.add_argument( + "--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all." 
+ ) parser.add_argument("--fp16", action="store_true") args = parser.parse_args() examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()] + if args.n_obs > 0: + examples = examples[: args.n_obs] generate_summaries_or_translations( - examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16 + examples, + args.save_path, + args.model_name, + batch_size=args.bs, + device=args.device, + fp16=args.fp16, + task=args.task, ) - - output_lns = [x.rstrip() for x in open(args.output_path).readlines()] - scores = {} - if args.reference_path is not None: - score_fn = {"bleu": calculate_bleu_score, "rouge": calculate_rouge}[args.metric] - reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()] - scores: dict = score_fn(output_lns, reference_lns) - if args.score_path is not None: - json.dump(scores, open("score_path", "w+")) + if args.reference_path is None: + return + # Compute scores + score_fn = calculate_bleu_score if "translation" in args.task else calculate_rouge + output_lns = [x.rstrip() for x in open(args.save_path).readlines()] + reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)] + scores: dict = score_fn(output_lns, reference_lns) + if args.score_path is not None: + json.dump(scores, open(args.score_path, "w+")) return scores diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 1826090c6d..bd1b108660 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -198,7 +198,7 @@ def test_run_eval_bart(model): assert not output_file_name.exists() articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] _dump_articles(input_file_name, articles) - testargs = ["run_eval.py", str(input_file_name), str(output_file_name), model] # TODO: test score_path + testargs = 
["run_eval.py", model, str(input_file_name), str(output_file_name)] # TODO: test score_path with patch.object(sys, "argv", testargs): run_generate() assert Path(output_file_name).exists() diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 99a2abbe20..7e2b862dd8 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -60,8 +60,9 @@ def lmap(f: Callable, x: Iterable) -> List: return list(map(f, x)) -def calculate_bleu_score(output_lns, refs_lns) -> dict: - return {"bleu": corpus_bleu(output_lns, [refs_lns]).score} +def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict: + """Uses sacrebleu's corpus_bleu implementation.""" + return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score} def trim_batch( diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 209abbb211..418dbeea2c 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -253,9 +253,9 @@ class MBartIntegrationTests(unittest.TestCase): with torch.no_grad(): logits, *other_stuff = model(**net_input) - expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device, dtype=model.dtype) - result_slice = logits[0][0][:3] - self.assertTrue(torch.allclose(expected_slice, result_slice, atol=TOLERANCE)) + expected_slice = [9.0078, 10.1113, 14.4787] + result_slice = logits[0][0][:3].tolist() + self.assertListEqual(expected_slice, result_slice) @slow def test_enro_generate(self):