diff --git a/examples/summarization/t5/evaluate_cnn.py b/examples/summarization/t5/evaluate_cnn.py index 18750183ac..535c11093b 100644 --- a/examples/summarization/t5/evaluate_cnn.py +++ b/examples/summarization/t5/evaluate_cnn.py @@ -14,13 +14,13 @@ def chunks(lst, n): yield lst[i : i + n] -def generate_summaries(lns, output_file_path, batch_size, device): +def generate_summaries(lns, output_file_path, model_size, batch_size, device): output_file = Path(output_file_path).open("w") - model = T5ForConditionalGeneration.from_pretrained("t5-large") + model = T5ForConditionalGeneration.from_pretrained(model_size) model.to(device) - tokenizer = T5Tokenizer.from_pretrained("t5-large") + tokenizer = T5Tokenizer.from_pretrained(model_size) # update config with summarization specific params task_specific_params = model.config.task_specific_params @@ -61,6 +61,12 @@ def calculate_rouge(output_lns, reference_lns, score_path): def run_generate(): parser = argparse.ArgumentParser() + parser.add_argument( + "model_size", + type=str, + help="T5 model size, either 't5-small', 't5-base' or 't5-large'. Defaults to base.", + default="t5-base", + ) parser.add_argument( "input_path", type=str, help="like cnn_dm/test_articles_input.txt", ) @@ -83,7 +89,7 @@ def run_generate(): source_lns = [x.rstrip() for x in open(args.input_path).readlines()] - generate_summaries(source_lns, args.output_path, args.batch_size, args.device) + generate_summaries(source_lns, args.output_path, args.model_size, args.batch_size, args.device) output_lns = [x.rstrip() for x in open(args.output_path).readlines()] reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()] diff --git a/examples/summarization/t5/test_t5_examples.py b/examples/summarization/t5/test_t5_examples.py index eb24c31c89..58b5db681e 100644 --- a/examples/summarization/t5/test_t5_examples.py +++ b/examples/summarization/t5/test_t5_examples.py @@ -22,7 +22,7 @@ class TestT5Examples(unittest.TestCase): tmp = Path(tempfile.gettempdir()) / "utest_generations.hypo" with tmp.open("w") as f: f.write("\n".join(articles)) - testargs = ["evaluate_cnn.py", str(tmp), "output.txt", str(tmp), "score.txt"] + testargs = ["evaluate_cnn.py", "t5-small", str(tmp), "output.txt", str(tmp), "score.txt"] with patch.object(sys, "argv", testargs): run_generate() self.assertTrue(Path("output.txt").exists())