fix eos_token_ids in test
This commit is contained in:
@@ -61,7 +61,7 @@ class ModelTester:
|
|||||||
self.hidden_dropout_prob = 0.1
|
self.hidden_dropout_prob = 0.1
|
||||||
self.attention_probs_dropout_prob = 0.1
|
self.attention_probs_dropout_prob = 0.1
|
||||||
self.max_position_embeddings = 20
|
self.max_position_embeddings = 20
|
||||||
self.eos_token_id = 2
|
self.eos_token_ids = [2]
|
||||||
self.pad_token_id = 1
|
self.pad_token_id = 1
|
||||||
self.bos_token_id = 0
|
self.bos_token_id = 0
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
@@ -436,7 +436,7 @@ class BartModelIntegrationTest(unittest.TestCase):
|
|||||||
num_beams=4,
|
num_beams=4,
|
||||||
max_length=extra_len + 2,
|
max_length=extra_len + 2,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
decoder_start_token_id=hf.config.eos_token_id,
|
decoder_start_token_id=hf.config.eos_token_ids[0],
|
||||||
) # repetition_penalty=10.,
|
) # repetition_penalty=10.,
|
||||||
expected_result = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
|
expected_result = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
|
||||||
generated = [tok.decode(g,) for g in gen_tokens]
|
generated = [tok.decode(g,) for g in gen_tokens]
|
||||||
@@ -481,7 +481,7 @@ class BartModelIntegrationTest(unittest.TestCase):
|
|||||||
no_repeat_ngram_size=3,
|
no_repeat_ngram_size=3,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
early_stopping=True,
|
early_stopping=True,
|
||||||
decoder_start_token_id=hf.config.eos_token_id,
|
decoder_start_token_id=hf.config.eos_token_ids[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
decoded = [
|
decoded = [
|
||||||
|
|||||||
Reference in New Issue
Block a user