From 80a169451479f97d737e2be433a7cbd30c39c6bb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 16 Apr 2020 20:00:41 +0200 Subject: [PATCH] [Examples, T5] Change newstest2013 to newstest2014 and clean up (#3817) * Refactored use of newstest2013 to newstest2014. Fixed bug where argparse consumed first command line argument as model_size argument rather than using default model_size by forcing explicit --model_size flag inclusion * More pythonic file handling through 'with' context * COSMETIC - ran Black and isort * Fixed reference to number of lines in newstest2014 * Fixed failing test. More pythonic file handling * finish PR from tholiao * remove outcommented lines * make style * make isort happy Co-authored-by: Thomas Liao --- examples/translation/t5/README.md | 16 ++++----- examples/translation/t5/evaluate_wmt.py | 45 ++++++++++++++----------- 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/examples/translation/t5/README.md b/examples/translation/t5/README.md index 85a179587a..7abcfb8a85 100644 --- a/examples/translation/t5/README.md +++ b/examples/translation/t5/README.md @@ -9,17 +9,17 @@ evaluated on the WMT English-German dataset. To be able to reproduce the authors' results on WMT English to German, you first need to download the WMT14 en-de news datasets. -Go on Stanford's official NLP [website](https://nlp.stanford.edu/projects/nmt/) and find "newstest2013.en" and "newstest2013.de" under WMT'14 English-German data or download the dataset directly via: +Go on Stanford's official NLP [website](https://nlp.stanford.edu/projects/nmt/) and find "newstest2014.en" and "newstest2014.de" under WMT'14 English-German data or download the dataset directly via: ```bash -curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.en > newstest2013.en -curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.de > newstest2013.de +curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.en > newstest2014.en +curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.de > newstest2014.de ``` -You should have 3000 sentence in each file. You can verify this by running: +You should have 2737 sentences in each file. You can verify this by running: ```bash -wc -l newstest2013.en # should give 3000 +wc -l newstest2014.en # should give 2737 ``` ### Usage @@ -29,8 +29,8 @@ Let's check the longest and shortest sentence in our file to find reasonable dec Get the longest and shortest sentence: ```bash -awk '{print NF}' newstest2013.en | sort -n | head -1 # shortest sentence has 1 word -awk '{print NF}' newstest2013.en | sort -n | tail -1 # longest sentence has 106 words +awk '{print NF}' newstest2014.en | sort -n | head -1 # shortest sentence has 2 word +awk '{print NF}' newstest2014.en | sort -n | tail -1 # longest sentence has 91 words ``` We will set our `max_length` to ~3 times the longest sentence and leave `min_length` to its default value of 0. @@ -38,7 +38,7 @@ We decode with beam search `num_beams=4` as proposed in the paper. Also as is co To create translation for each in dataset and get a final BLEU score, run: ```bash -python evaluate_wmt.py newstest2013_de_translations.txt newsstest2013_en_de_bleu.txt +python evaluate_wmt.py newstest2014_de_translations.txt newsstest2014_en_de_bleu.txt ``` the default batch size, 16, fits in 16GB GPU memory, but may need to be adjusted to fit your system. diff --git a/examples/translation/t5/evaluate_wmt.py b/examples/translation/t5/evaluate_wmt.py index 533811271b..4db5564f76 100644 --- a/examples/translation/t5/evaluate_wmt.py +++ b/examples/translation/t5/evaluate_wmt.py @@ -15,8 +15,6 @@ def chunks(lst, n): def generate_translations(lns, output_file_path, model_size, batch_size, device): - output_file = Path(output_file_path).open("w") - model = T5ForConditionalGeneration.from_pretrained(model_size) model.to(device) @@ -27,27 +25,29 @@ def generate_translations(lns, output_file_path, model_size, batch_size, device) if task_specific_params is not None: model.config.update(task_specific_params.get("translation_en_to_de", {})) - for batch in tqdm(list(chunks(lns, batch_size))): - batch = [model.config.prefix + text for text in batch] + with Path(output_file_path).open("w") as output_file: + for batch in tqdm(list(chunks(lns, batch_size))): + batch = [model.config.prefix + text for text in batch] - dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True) + dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True) - input_ids = dct["input_ids"].to(device) - attention_mask = dct["attention_mask"].to(device) + input_ids = dct["input_ids"].to(device) + attention_mask = dct["attention_mask"].to(device) - translations = model.generate(input_ids=input_ids, attention_mask=attention_mask) - dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in translations] + translations = model.generate(input_ids=input_ids, attention_mask=attention_mask) + dec = [ + tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in translations + ] - for hypothesis in dec: - output_file.write(hypothesis + "\n") - output_file.flush() + for hypothesis in dec: + output_file.write(hypothesis + "\n") def calculate_bleu_score(output_lns, refs_lns, score_path): bleu = corpus_bleu(output_lns, [refs_lns]) result = "BLEU score: {}".format(bleu.score) - score_file = Path(score_path).open("w") - score_file.write(result) + with Path(score_path).open("w") as score_file: + score_file.write(result) def run_generate(): @@ -59,13 +59,13 @@ def run_generate(): default="t5-base", ) parser.add_argument( - "input_path", type=str, help="like wmt/newstest2013.en", + "input_path", type=str, help="like wmt/newstest2014.en", ) parser.add_argument( "output_path", type=str, help="where to save translation", ) parser.add_argument( - "reference_path", type=str, help="like wmt/newstest2013.de", + "reference_path", type=str, help="like wmt/newstest2014.de", ) parser.add_argument( "score_path", type=str, help="where to save the bleu score", @@ -82,12 +82,19 @@ def run_generate(): dash_pattern = (" ##AT##-##AT## ", "-") - input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.input_path).readlines()] + # Read input lines into python + with open(args.input_path, "r") as input_file: + input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in input_file.readlines()] generate_translations(input_lns, args.output_path, args.model_size, args.batch_size, args.device) - output_lns = [x.strip() for x in open(args.output_path).readlines()] - refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.reference_path).readlines()] + # Read generated lines into python + with open(args.output_path, "r") as output_file: + output_lns = [x.strip() for x in output_file.readlines()] + + # Read reference lines into python + with open(args.reference_path, "r") as reference_file: + refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in reference_file.readlines()] calculate_bleu_score(output_lns, refs_lns, args.score_path)