Deprecate prepare_seq2seq_batch (#10287)
* Deprecate prepare_seq2seq_batch * Fix last tests * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com> * More review comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -129,10 +129,7 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
|
||||
src_text = ["this is gunna be a long sentence " * 20]
|
||||
assert isinstance(src_text[0], str)
|
||||
desired_max_length = 10
|
||||
ids = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_text,
|
||||
max_length=desired_max_length,
|
||||
).input_ids[0]
|
||||
ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
|
||||
self.assertEqual(ids[0], EN_CODE)
|
||||
self.assertEqual(ids[-1], 2)
|
||||
self.assertEqual(len(ids), desired_max_length)
|
||||
@@ -147,32 +144,38 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
|
||||
new_tok = MBart50Tokenizer.from_pretrained(tmpdirname)
|
||||
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
|
||||
|
||||
# prepare_seq2seq_batch tests below
|
||||
|
||||
@require_torch
|
||||
def test_batch_fairseq_parity(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
batch = self.tokenizer(self.src_text, padding=True)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
|
||||
labels = labels.tolist()
|
||||
|
||||
for k in batch:
|
||||
batch[k] = batch[k].tolist()
|
||||
# batch = {k: v.tolist() for k,v in batch.items()}
|
||||
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
|
||||
# batch.decoder_inputs_ids[0][0] ==
|
||||
assert batch.input_ids[1][0] == EN_CODE
|
||||
assert batch.input_ids[1][-1] == 2
|
||||
assert batch.labels[1][0] == RO_CODE
|
||||
assert batch.labels[1][-1] == 2
|
||||
assert labels[1][0] == RO_CODE
|
||||
assert labels[1][-1] == 2
|
||||
assert batch.decoder_input_ids[1][:2] == [2, RO_CODE]
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_prepare_seq2seq_batch(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
def test_tokenizer_prepare_batch(self):
|
||||
batch = self.tokenizer(
|
||||
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(
|
||||
self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
self.assertEqual((2, 14), batch.input_ids.shape)
|
||||
@@ -185,16 +188,11 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
|
||||
def test_seq2seq_max_target_length(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||
|
||||
Reference in New Issue
Block a user