From 5ad2ea06af898a95744a268332431f050c62a862 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 26 Mar 2020 19:07:59 +0100 Subject: [PATCH] Add wmt translation example (#3428) * add translation example * make style * adapt docstring * add gpu device as input for example * small renaming * better README --- examples/requirements.txt | 3 +- examples/translation/t5/README.md | 51 ++++++++++++ examples/translation/t5/__init__.py | 0 examples/translation/t5/evaluate_wmt.py | 90 +++++++++++++++++++++ examples/translation/t5/test_t5_examples.py | 28 +++++++ 5 files changed, 171 insertions(+), 1 deletion(-) create mode 100644 examples/translation/t5/README.md create mode 100644 examples/translation/t5/__init__.py create mode 100644 examples/translation/t5/evaluate_wmt.py create mode 100644 examples/translation/t5/test_t5_examples.py diff --git a/examples/requirements.txt b/examples/requirements.txt index e1f1a2c114..70f1b9999a 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -3,5 +3,6 @@ tensorboard scikit-learn seqeval psutil +sacrebleu rouge-score -tensorflow_datasets +tensorflow_datasets \ No newline at end of file diff --git a/examples/translation/t5/README.md b/examples/translation/t5/README.md new file mode 100644 index 0000000000..85a179587a --- /dev/null +++ b/examples/translation/t5/README.md @@ -0,0 +1,51 @@ +***This script evaluates the multitask pre-trained checkpoint for ``t5-base`` (see paper [here](https://arxiv.org/pdf/1910.10683.pdf)) on the English to German WMT dataset. Please note that the results in the paper were attained using a model fine-tuned on translation, so that results will be worse here by approx. 1.5 BLEU points*** + +### Intro + +This example shows how T5 (here the official [paper](https://arxiv.org/abs/1910.10683)) can be +evaluated on the WMT English-German dataset. + +### Get the WMT Data + +To be able to reproduce the authors' results on WMT English to German, you first need to download +the WMT14 en-de news datasets. +Go on Stanford's official NLP [website](https://nlp.stanford.edu/projects/nmt/) and find "newstest2013.en" and "newstest2013.de" under WMT'14 English-German data or download the dataset directly via: + +```bash +curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.en > newstest2013.en +curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.de > newstest2013.de +``` + +You should have 3000 sentence in each file. You can verify this by running: + +```bash +wc -l newstest2013.en # should give 3000 +``` + +### Usage + +Let's check the longest and shortest sentence in our file to find reasonable decoding hyperparameters: + +Get the longest and shortest sentence: + +```bash +awk '{print NF}' newstest2013.en | sort -n | head -1 # shortest sentence has 1 word +awk '{print NF}' newstest2013.en | sort -n | tail -1 # longest sentence has 106 words +``` + +We will set our `max_length` to ~3 times the longest sentence and leave `min_length` to its default value of 0. +We decode with beam search `num_beams=4` as proposed in the paper. Also as is common in beam search we set `early_stopping=True` and `length_penalty=2.0`. + +To create translation for each in dataset and get a final BLEU score, run: +```bash +python evaluate_wmt.py newstest2013_de_translations.txt newsstest2013_en_de_bleu.txt +``` +the default batch size, 16, fits in 16GB GPU memory, but may need to be adjusted to fit your system. + +### Where is the code? +The core model is in `src/transformers/modeling_t5.py`. This directory only contains examples. + +### BLEU Scores + +The BLEU score is calculated using [sacrebleu](https://github.com/mjpost/sacreBLEU) by mjpost. +To get the BLEU score we used diff --git a/examples/translation/t5/__init__.py b/examples/translation/t5/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/translation/t5/evaluate_wmt.py b/examples/translation/t5/evaluate_wmt.py new file mode 100644 index 0000000000..307065d0a9 --- /dev/null +++ b/examples/translation/t5/evaluate_wmt.py @@ -0,0 +1,90 @@ +import argparse +from pathlib import Path + +import torch +from tqdm import tqdm + +from sacrebleu import corpus_bleu +from transformers import T5ForConditionalGeneration, T5Tokenizer + + +def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +def generate_translations(lns, output_file_path, batch_size, device): + output_file = Path(output_file_path).open("w") + + model = T5ForConditionalGeneration.from_pretrained("t5-base") + model.to(device) + + tokenizer = T5Tokenizer.from_pretrained("t5-base") + + # update config with summarization specific params + task_specific_params = model.config.task_specific_params + if task_specific_params is not None: + model.config.update(task_specific_params.get("translation_en_to_de", {})) + + for batch in tqdm(list(chunks(lns, batch_size))): + batch = [model.config.prefix + text for text in batch] + + dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True) + + input_ids = dct["input_ids"].to(device) + attention_mask = dct["attention_mask"].to(device) + + translations = model.generate(input_ids=input_ids, attention_mask=attention_mask) + dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in translations] + + for hypothesis in dec: + output_file.write(hypothesis + "\n") + output_file.flush() + + +def calculate_bleu_score(output_lns, refs_lns, score_path): + bleu = corpus_bleu(output_lns, [refs_lns]) + result = "BLEU score: {}".format(bleu.score) + score_file = Path(score_path).open("w") + score_file.write(result) + + +def run_generate(): + parser = argparse.ArgumentParser() + parser.add_argument( + "input_path", type=str, help="like wmt/newstest2013.en", + ) + parser.add_argument( + "output_path", type=str, help="where to save translation", + ) + parser.add_argument( + "reference_path", type=str, help="like wmt/newstest2013.de", + ) + parser.add_argument( + "score_path", type=str, help="where to save the bleu score", + ) + parser.add_argument( + "--batch_size", type=int, default=16, required=False, help="batch size: how many to summarize at a time", + ) + parser.add_argument( + "--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.", + ) + + args = parser.parse_args() + args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + + dash_pattern = (" ##AT##-##AT## ", "-") + + input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.input_path).readlines()] + + generate_translations(input_lns, args.output_path, args.batch_size, args.device) + + output_lns = [x.strip() for x in open(args.output_path).readlines()] + refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.reference_path).readlines()] + + calculate_bleu_score(output_lns, refs_lns, args.score_path) + + +if __name__ == "__main__": + run_generate() diff --git a/examples/translation/t5/test_t5_examples.py b/examples/translation/t5/test_t5_examples.py new file mode 100644 index 0000000000..c7d1fe6882 --- /dev/null +++ b/examples/translation/t5/test_t5_examples.py @@ -0,0 +1,28 @@ +import logging +import sys +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from .evaluate_wmt import run_generate + + +text = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() + + +class TestT5Examples(unittest.TestCase): + def test_t5_cli(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + tmp = Path(tempfile.gettempdir()) / "utest_generations.hypo" + with tmp.open("w") as f: + f.write("\n".join(text)) + testargs = ["evaluate_cnn.py", str(tmp), "output.txt", str(tmp), "score.txt"] + with patch.object(sys, "argv", testargs): + run_generate() + self.assertTrue(Path("output.txt").exists())