diff --git a/examples/summarization/bart/evaluate_cnn.py b/examples/summarization/bart/evaluate_cnn.py index 474fcd46cb..93590614eb 100644 --- a/examples/summarization/bart/evaluate_cnn.py +++ b/examples/summarization/bart/evaluate_cnn.py @@ -20,6 +20,10 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE): fout = Path(out_file).open("w") model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device) tokenizer = BartTokenizer.from_pretrained("bart-large") + + max_length = 140 + min_length = 55 + for batch in tqdm(list(chunks(lns, batch_size))): dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True) summaries = model.generate( @@ -27,11 +31,12 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE): attention_mask=dct["attention_mask"].to(device), num_beams=4, length_penalty=2.0, - max_length=142, # +2 from original because we start at step=1 and stop before max_length - min_length=56, # +1 from original because we start at step=1 + max_length=max_length + 2, # +2 from original because we start at step=1 and stop before max_length + min_length=min_length + 1, # +1 from original because we start at step=1 no_repeat_ngram_size=3, early_stopping=True, do_sample=False, + decoder_start_token_id=model.config.eos_token_ids[0] ) dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] for hypothesis in dec: