Tokenizers should be framework agnostic (#8599)
* Tokenizers should be framework agnostic * Run the slow tests * Not testing * Fix documentation * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -132,9 +132,9 @@ class MarianIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(self.expected_text, generated_words)
|
||||
|
||||
def translate_src_text(self, **tokenizer_kwargs):
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch(src_texts=self.src_text, **tokenizer_kwargs).to(
|
||||
torch_device
|
||||
)
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=self.src_text, return_tensors="pt", **tokenizer_kwargs
|
||||
).to(torch_device)
|
||||
self.assertEqual(self.model.device, model_inputs.input_ids.device)
|
||||
generated_ids = self.model.generate(
|
||||
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
|
||||
@@ -151,7 +151,9 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
||||
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
|
||||
expected_ids = [38, 121, 14, 697, 38848, 0]
|
||||
|
||||
model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device)
|
||||
model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
|
||||
|
||||
@@ -171,12 +173,16 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
||||
|
||||
def test_unk_support(self):
|
||||
t = self.tokenizer
|
||||
ids = t.prepare_seq2seq_batch(["||"]).to(torch_device).input_ids[0].tolist()
|
||||
ids = t.prepare_seq2seq_batch(["||"], return_tensors="pt").to(torch_device).input_ids[0].tolist()
|
||||
expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
|
||||
self.assertEqual(expected, ids)
|
||||
|
||||
def test_pad_not_split(self):
|
||||
input_ids_w_pad = self.tokenizer.prepare_seq2seq_batch(["I am a small frog <pad>"]).input_ids[0].tolist()
|
||||
input_ids_w_pad = (
|
||||
self.tokenizer.prepare_seq2seq_batch(["I am a small frog <pad>"], return_tensors="pt")
|
||||
.input_ids[0]
|
||||
.tolist()
|
||||
)
|
||||
expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad
|
||||
self.assertListEqual(expected_w_pad, input_ids_w_pad)
|
||||
|
||||
@@ -294,7 +300,7 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
|
||||
normalized = self.tokenizer.normalize("")
|
||||
self.assertIsInstance(normalized, str)
|
||||
with self.assertRaises(ValueError):
|
||||
self.tokenizer.prepare_seq2seq_batch([""])
|
||||
self.tokenizer.prepare_seq2seq_batch([""], return_tensors="pt")
|
||||
|
||||
@slow
|
||||
def test_pipeline(self):
|
||||
|
||||
@@ -92,7 +92,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
@slow
|
||||
def test_enro_generate_one(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
["UN Chief Says There Is No Military Solution in Syria"]
|
||||
["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt"
|
||||
).to(torch_device)
|
||||
translated_tokens = self.model.generate(**batch)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
@@ -101,7 +101,9 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@slow
|
||||
def test_enro_generate_batch(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text).to(torch_device)
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
translated_tokens = self.model.generate(**batch)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
assert self.tgt_text == decoded
|
||||
@@ -153,7 +155,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@unittest.skip("This test is broken, still generates english")
|
||||
def test_cc25_generate(self):
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]]).to(torch_device)
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]], return_tensors="pt").to(torch_device)
|
||||
translated_tokens = self.model.generate(
|
||||
input_ids=inputs["input_ids"].to(torch_device),
|
||||
decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
|
||||
@@ -163,7 +165,9 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@slow
|
||||
def test_fill_mask(self):
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"]).to(torch_device)
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"], return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
outputs = self.model.generate(
|
||||
inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
|
||||
)
|
||||
|
||||
@@ -1794,7 +1794,7 @@ class TokenizerTesterMixin:
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.labels.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3)
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3, return_tensors="pt")
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.labels.shape[1], 3)
|
||||
|
||||
|
||||
@@ -165,7 +165,6 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
desired_max_length = 10
|
||||
ids = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_text,
|
||||
return_tensors=None,
|
||||
max_length=desired_max_length,
|
||||
).input_ids[0]
|
||||
self.assertEqual(ids[-2], 2)
|
||||
@@ -203,9 +202,7 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_enro_tokenizer_prepare_seq2seq_batch(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text,
|
||||
tgt_texts=self.tgt_text,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
@@ -221,13 +218,15 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_seq2seq_max_target_length(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||
|
||||
@@ -61,7 +61,9 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_pegasus_large_seq2seq_truncation(self):
|
||||
src_texts = ["This is going to be way too long." * 150, "short example"]
|
||||
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
|
||||
batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, max_target_length=5)
|
||||
batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(
|
||||
src_texts, tgt_texts=tgt_texts, max_target_length=5, return_tensors="pt"
|
||||
)
|
||||
assert batch.input_ids.shape == (2, 1024)
|
||||
assert batch.attention_mask.shape == (2, 1024)
|
||||
assert "labels" in batch # because tgt_texts was specified
|
||||
|
||||
Reference in New Issue
Block a user