From 3aca02efb3d4ff2d6d231c55d3b9367e61b7c0c4 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 9 Mar 2020 15:09:35 -0400 Subject: [PATCH] Bart example: model.to(device) (#3194) --- examples/summarization/bart/evaluate_cnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/summarization/bart/evaluate_cnn.py b/examples/summarization/bart/evaluate_cnn.py index 7ce680ebe1..f35ddcb154 100644 --- a/examples/summarization/bart/evaluate_cnn.py +++ b/examples/summarization/bart/evaluate_cnn.py @@ -18,7 +18,7 @@ def chunks(lst, n): def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE): fout = Path(out_file).open("w") - model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,) + model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device) tokenizer = BartTokenizer.from_pretrained("bart-large") for batch in tqdm(list(chunks(lns, batch_size))): dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)