fixed typo

This commit is contained in:
Patrick von Platen
2020-03-09 20:01:20 +01:00
parent a2c8e516c2
commit 374deef48d
2 changed files with 2 additions and 2 deletions

View File

@@ -287,7 +287,7 @@ class BartHeadTests(unittest.TestCase):
new_input_ids = lm_model.generate(
input_ids.clone(), num_return_sequences=1, num_beams=2, no_repeat_ngram_size=3, max_length=max_length
)
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1))
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
# TODO(SS): uneven length batches, empty inputs
def test_shift_tokens_right(self):