From 8deff3acf2fd34e2a2161a4b833f1ff78a0d5d52 Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Mon, 30 Mar 2020 12:28:27 -0400
Subject: [PATCH] =?UTF-8?q?[bart-tiny-random]=20Put=20a=205MB=20model=20on?=
 =?UTF-8?q?=20S3=20to=20allow=20faster=20exampl=E2=80=A6=20(#3488)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Reviewer note: the positional `model_name` argument is declared with
nargs="?" so that its default ("bart-large-cnn") applies when the argument
is omitted; argparse ignores `default` on a required positional. The blob
ids on the `index` lines predate this one-line amendment.

 examples/summarization/bart/evaluate_cnn.py       | 15 ++++++++++-----
 examples/summarization/bart/test_bart_examples.py |  3 ++-
 tests/test_modeling_bart.py                       | 11 +++++++++++
 3 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/examples/summarization/bart/evaluate_cnn.py b/examples/summarization/bart/evaluate_cnn.py
index 5c69dc921f..0903e0c0f9 100644
--- a/examples/summarization/bart/evaluate_cnn.py
+++ b/examples/summarization/bart/evaluate_cnn.py
@@ -16,15 +16,17 @@ def chunks(lst, n):
         yield lst[i : i + n]
 
 
-def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
+def generate_summaries(
+    examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
+):
     fout = Path(out_file).open("w")
-    model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device)
+    model = BartForConditionalGeneration.from_pretrained(model_name, output_past=True,).to(device)
     tokenizer = BartTokenizer.from_pretrained("bart-large")
 
     max_length = 140
     min_length = 55
 
-    for batch in tqdm(list(chunks(lns, batch_size))):
+    for batch in tqdm(list(chunks(examples, batch_size))):
         dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
         summaries = model.generate(
             input_ids=dct["input_ids"].to(device),
@@ -51,6 +53,9 @@ def _run_generate():
     parser.add_argument(
         "output_path", type=str, help="where to save summaries",
     )
+    parser.add_argument(
+        "model_name", type=str, nargs="?", default="bart-large-cnn", help="like bart-large-cnn",
+    )
     parser.add_argument(
         "--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
     )
@@ -58,8 +63,8 @@ def _run_generate():
         "--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
     )
     args = parser.parse_args()
-    lns = [" " + x.rstrip() for x in open(args.source_path).readlines()]
-    generate_summaries(lns, args.output_path, batch_size=args.bs, device=args.device)
+    examples = [" " + x.rstrip() for x in open(args.source_path).readlines()]
+    generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device)
 
 
 if __name__ == "__main__":
diff --git a/examples/summarization/bart/test_bart_examples.py b/examples/summarization/bart/test_bart_examples.py
index 18064cc5d2..b1d1d8e756 100644
--- a/examples/summarization/bart/test_bart_examples.py
+++ b/examples/summarization/bart/test_bart_examples.py
@@ -25,7 +25,8 @@ class TestBartExamples(unittest.TestCase):
         tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
         with tmp.open("w") as f:
             f.write("\n".join(articles))
-        testargs = ["evaluate_cnn.py", str(tmp), output_file_name]
+
+        testargs = ["evaluate_cnn.py", str(tmp), output_file_name, "sshleifer/bart-tiny-random"]
         with patch.object(sys, "argv", testargs):
             _run_generate()
             self.assertTrue(Path(output_file_name).exists())
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index af59de5aa3..10dfd5b6c8 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -27,7 +27,9 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
 if is_torch_available():
     import torch
     from transformers import (
+        AutoModel,
         AutoModelForSequenceClassification,
+        AutoTokenizer,
         BartModel,
         BartForConditionalGeneration,
         BartForSequenceClassification,
@@ -183,6 +185,15 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
     def test_inputs_embeds(self):
         pass
 
+    def test_tiny_model(self):
+        model_name = "sshleifer/bart-tiny-random"
+        tiny = AutoModel.from_pretrained(model_name)  # same vocab size
+        tok = AutoTokenizer.from_pretrained(model_name)  # same tokenizer
+        inputs_dict = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt")
+
+        with torch.no_grad():
+            tiny(**inputs_dict)
+
 
 @require_torch
 class BartHeadTests(unittest.TestCase):