Tokenizers should be framework agnostic (#8599)
* Tokenizers should be framework agnostic * Run the slow tests * Not testing * Fix documentation * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -165,7 +165,6 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
desired_max_length = 10
|
||||
ids = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_text,
|
||||
return_tensors=None,
|
||||
max_length=desired_max_length,
|
||||
).input_ids[0]
|
||||
self.assertEqual(ids[-2], 2)
|
||||
@@ -203,9 +202,7 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_enro_tokenizer_prepare_seq2seq_batch(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text,
|
||||
tgt_texts=self.tgt_text,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
@@ -221,13 +218,15 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_seq2seq_max_target_length(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||
|
||||
Reference in New Issue
Block a user