Re-add the EOS token as the decoder start token to get good BART results

This commit is contained in:
Patrick von Platen
2020-03-12 20:17:50 +01:00
parent c11160114a
commit 6047f46b19
2 changed files with 15 additions and 2 deletions

View File

@@ -432,7 +432,11 @@ class BartModelIntegrationTest(unittest.TestCase):
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
extra_len = 20
gen_tokens = hf.generate(
-tokens, num_beams=4, max_length=extra_len + 2, do_sample=False
+tokens,
+num_beams=4,
+max_length=extra_len + 2,
+do_sample=False,
+decoder_start_token_id=hf.config.eos_token_id,
) # repetition_penalty=10.,
expected_result = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
generated = [tok.decode(g,) for g in gen_tokens]
@@ -477,6 +481,7 @@ class BartModelIntegrationTest(unittest.TestCase):
no_repeat_ngram_size=3,
do_sample=False,
early_stopping=True,
+decoder_start_token_id=hf.config.eos_token_id,
)
decoded = [