MBART: support summarization tasks where max_src_len > max_tgt_len (#6003)
* MBART: support summarization tasks * fix test * Style * add tokenizer test
This commit is contained in:
@@ -137,6 +137,18 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(self.tokenizer.prefix_tokens, [])
|
||||
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
|
||||
|
||||
def test_max_target_length(self):
|
||||
|
||||
batch = self.tokenizer.prepare_translation_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
|
||||
)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = self.tokenizer.prepare_translation_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||
|
||||
def test_enro_tokenizer_batch_encode_plus(self):
|
||||
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
|
||||
self.assertListEqual(self.expected_src_tokens, ids)
|
||||
|
||||
Reference in New Issue
Block a user