Rename BartForMaskedLM -> BartForConditionalGeneration (#3114)
* improved documentation
This commit is contained in:
@@ -4,7 +4,7 @@ from pathlib import Path
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BartForMaskedLM, BartTokenizer
|
||||
from transformers import BartForConditionalGeneration, BartTokenizer
|
||||
|
||||
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -18,7 +18,7 @@ def chunks(lst, n):
|
||||
|
||||
def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
|
||||
fout = Path(out_file).open("w")
|
||||
model = BartForMaskedLM.from_pretrained("bart-large-cnn", output_past=True,)
|
||||
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,)
|
||||
tokenizer = BartTokenizer.from_pretrained("bart-large")
|
||||
for batch in tqdm(list(chunks(lns, batch_size))):
|
||||
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
|
||||
|
||||
Reference in New Issue
Block a user