finalize generation merge

This commit is contained in:
Patrick von Platen
2020-03-11 11:53:36 +01:00
parent 1ba21f96ca
commit a332cc9f7f
4 changed files with 10 additions and 13 deletions

View File

@@ -82,7 +82,7 @@ class ModelTester:
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
eos_token_ids=self.eos_token_id,
eos_token_ids=[self.eos_token_id],
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
)
@@ -214,7 +214,7 @@ class BartHeadTests(unittest.TestCase):
decoder_ffn_dim=32,
max_position_embeddings=48,
output_past=output_past,
eos_token_id=2,
eos_token_ids=[2],
pad_token_id=1,
bos_token_id=0,
)
@@ -276,7 +276,7 @@ class BartHeadTests(unittest.TestCase):
decoder_ffn_dim=32,
max_position_embeddings=48,
output_past=True,
eos_token_ids=2,
eos_token_ids=[2],
pad_token_id=1,
bos_token_id=0,
)
@@ -287,7 +287,7 @@ class BartHeadTests(unittest.TestCase):
new_input_ids = lm_model.generate(
input_ids.clone(), num_return_sequences=1, num_beams=2, no_repeat_ngram_size=3, max_length=max_length
)
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1))
# TODO(SS): uneven length batches, empty inputs
def test_shift_tokens_right(self):