diff --git a/examples/summarization/bart/evaluate_cnn.py b/examples/summarization/bart/evaluate_cnn.py index 7ce680ebe1..f35ddcb154 100644 --- a/examples/summarization/bart/evaluate_cnn.py +++ b/examples/summarization/bart/evaluate_cnn.py @@ -18,7 +18,7 @@ def chunks(lst, n): def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE): fout = Path(out_file).open("w") - model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,) + model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device) tokenizer = BartTokenizer.from_pretrained("bart-large") for batch in tqdm(list(chunks(lns, batch_size))): dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)