update file to new starting token logic
This commit is contained in:
@@ -20,6 +20,10 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
|
|||||||
fout = Path(out_file).open("w")
|
fout = Path(out_file).open("w")
|
||||||
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device)
|
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device)
|
||||||
tokenizer = BartTokenizer.from_pretrained("bart-large")
|
tokenizer = BartTokenizer.from_pretrained("bart-large")
|
||||||
|
|
||||||
|
max_length = 140
|
||||||
|
min_length = 55
|
||||||
|
|
||||||
for batch in tqdm(list(chunks(lns, batch_size))):
|
for batch in tqdm(list(chunks(lns, batch_size))):
|
||||||
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
|
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
|
||||||
summaries = model.generate(
|
summaries = model.generate(
|
||||||
@@ -27,11 +31,12 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
|
|||||||
attention_mask=dct["attention_mask"].to(device),
|
attention_mask=dct["attention_mask"].to(device),
|
||||||
num_beams=4,
|
num_beams=4,
|
||||||
length_penalty=2.0,
|
length_penalty=2.0,
|
||||||
max_length=142, # +2 from original because we start at step=1 and stop before max_length
|
max_length=max_length + 2, # +2 from original because we start at step=1 and stop before max_length
|
||||||
min_length=56, # +1 from original because we start at step=1
|
min_length=min_length + 1, # +1 from original because we start at step=1
|
||||||
no_repeat_ngram_size=3,
|
no_repeat_ngram_size=3,
|
||||||
early_stopping=True,
|
early_stopping=True,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
|
decoder_start_token_id=model.config.eos_token_ids[0]
|
||||||
)
|
)
|
||||||
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
|
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
|
||||||
for hypothesis in dec:
|
for hypothesis in dec:
|
||||||
|
|||||||
Reference in New Issue
Block a user