From f1c71da1154cb8bf58e07eccb4c1a3fcae83efb8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 12 Mar 2020 21:00:54 +0100 Subject: [PATCH] fix eos_token_ids in test --- tests/test_modeling_bart.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 8ad02feb2a..a0ad29830d 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -61,7 +61,7 @@ class ModelTester: self.hidden_dropout_prob = 0.1 self.attention_probs_dropout_prob = 0.1 self.max_position_embeddings = 20 - self.eos_token_id = 2 + self.eos_token_ids = [2] self.pad_token_id = 1 self.bos_token_id = 0 torch.manual_seed(0) @@ -436,7 +436,7 @@ class BartModelIntegrationTest(unittest.TestCase): num_beams=4, max_length=extra_len + 2, do_sample=False, - decoder_start_token_id=hf.config.eos_token_id, + decoder_start_token_id=hf.config.eos_token_ids[0], ) # repetition_penalty=10., expected_result = "The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday." generated = [tok.decode(g,) for g in gen_tokens] @@ -481,7 +481,7 @@ class BartModelIntegrationTest(unittest.TestCase): no_repeat_ngram_size=3, do_sample=False, early_stopping=True, - decoder_start_token_id=hf.config.eos_token_id, + decoder_start_token_id=hf.config.eos_token_ids[0], ) decoded = [