Deprecate prepare_seq2seq_batch (#10287)
* Deprecate prepare_seq2seq_batch * Fix last tests * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com> * More review comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -354,9 +354,7 @@ class MarianIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(self.expected_text, generated_words)
|
||||
|
||||
def translate_src_text(self, **tokenizer_kwargs):
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=self.src_text, return_tensors="pt", **tokenizer_kwargs
|
||||
).to(torch_device)
|
||||
model_inputs = self.tokenizer(self.src_text, return_tensors="pt", **tokenizer_kwargs).to(torch_device)
|
||||
self.assertEqual(self.model.device, model_inputs.input_ids.device)
|
||||
generated_ids = self.model.generate(
|
||||
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
|
||||
@@ -373,9 +371,10 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
||||
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
|
||||
expected_ids = [38, 121, 14, 697, 38848, 0]
|
||||
|
||||
model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
model_inputs = self.tokenizer(src, return_tensors="pt").to(torch_device)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(tgt, return_tensors="pt")
|
||||
model_inputs["labels"] = targets["input_ids"].to(torch_device)
|
||||
|
||||
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
|
||||
|
||||
@@ -397,16 +396,12 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
||||
|
||||
def test_unk_support(self):
|
||||
t = self.tokenizer
|
||||
ids = t.prepare_seq2seq_batch(["||"], return_tensors="pt").to(torch_device).input_ids[0].tolist()
|
||||
ids = t(["||"], return_tensors="pt").to(torch_device).input_ids[0].tolist()
|
||||
expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
|
||||
self.assertEqual(expected, ids)
|
||||
|
||||
def test_pad_not_split(self):
|
||||
input_ids_w_pad = (
|
||||
self.tokenizer.prepare_seq2seq_batch(["I am a small frog <pad>"], return_tensors="pt")
|
||||
.input_ids[0]
|
||||
.tolist()
|
||||
)
|
||||
input_ids_w_pad = self.tokenizer(["I am a small frog <pad>"], return_tensors="pt").input_ids[0].tolist()
|
||||
expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad
|
||||
self.assertListEqual(expected_w_pad, input_ids_w_pad)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user