Tokenizers should be framework agnostic (#8599)
* Tokenizers should be framework agnostic * Run the slow tests * Not testing * Fix documentation * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -92,7 +92,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
@slow
|
||||
def test_enro_generate_one(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
["UN Chief Says There Is No Military Solution in Syria"]
|
||||
["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt"
|
||||
).to(torch_device)
|
||||
translated_tokens = self.model.generate(**batch)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
@@ -101,7 +101,9 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@slow
|
||||
def test_enro_generate_batch(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text).to(torch_device)
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
translated_tokens = self.model.generate(**batch)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
assert self.tgt_text == decoded
|
||||
@@ -153,7 +155,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@unittest.skip("This test is broken, still generates english")
|
||||
def test_cc25_generate(self):
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]]).to(torch_device)
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]], return_tensors="pt").to(torch_device)
|
||||
translated_tokens = self.model.generate(
|
||||
input_ids=inputs["input_ids"].to(torch_device),
|
||||
decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
|
||||
@@ -163,7 +165,9 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@slow
|
||||
def test_fill_mask(self):
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"]).to(torch_device)
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"], return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
outputs = self.model.generate(
|
||||
inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user