Replace as_target context managers by direct calls (#18325)
* Preliminary work on tokenizers * Quality + fix tests * Treat processors * Fix pad * Remove all uses of in tests, docs and examples * Replace all as_target_tokenizer * Fix tests * Fix quality * Update examples/flax/image-captioning/run_image_captioning_flax.py Co-authored-by: amyeroberts <amy@huggingface.co> * Style Co-authored-by: amyeroberts <amy@huggingface.co>
This commit is contained in:
@@ -299,33 +299,26 @@ class PLBartPythonEnIntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
def test_batch_fairseq_parity(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
|
||||
batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
|
||||
self.assertEqual(batch.input_ids[1][-2:], [2, PYTHON_CODE])
|
||||
self.assertEqual(batch.input_ids[1][-2:].tolist(), [2, PYTHON_CODE])
|
||||
self.assertEqual(batch.decoder_input_ids[1][0], EN_CODE)
|
||||
self.assertEqual(batch.decoder_input_ids[1][-1], 2)
|
||||
self.assertEqual(labels[1][-2:].tolist(), [2, EN_CODE])
|
||||
self.assertEqual(batch.labels[1][-2:].tolist(), [2, EN_CODE])
|
||||
|
||||
@require_torch
|
||||
def test_python_en_tokenizer_prepare_batch(self):
|
||||
batch = self.tokenizer(
|
||||
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
self.src_text,
|
||||
text_target=self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(
|
||||
self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
@@ -340,8 +333,9 @@ class PLBartPythonEnIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_seq2seq_max_length(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
|
||||
targets = self.tokenizer(
|
||||
text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user