Replace as_target context managers by direct calls (#18325)
* Preliminary work on tokenizers
* Quality + fix tests
* Treat processors
* Fix pad
* Remove all uses of as_target context managers in tests, docs and examples
* Replace all as_target_tokenizer
* Fix tests
* Fix quality
* Update examples/flax/image-captioning/run_image_captioning_flax.py

Co-authored-by: amyeroberts <amy@huggingface.co>

* Style

Co-authored-by: amyeroberts <amy@huggingface.co>
This commit is contained in:
@@ -112,14 +112,13 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertNotIn("decoder_attention_mask", batch)
|
||||
|
||||
@require_torch
|
||||
def test_as_target_tokenizer_target_length(self):
|
||||
def test_tokenizer_as_target_length(self):
|
||||
tgt_text = [
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, max_length=32, padding="max_length", return_tensors="pt")
|
||||
targets = tokenizer(text_target=tgt_text, max_length=32, padding="max_length", return_tensors="pt")
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
@require_torch
|
||||
@@ -140,8 +139,7 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
||||
]
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
inputs = tokenizer(src_text, return_tensors="pt")
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, return_tensors="pt")
|
||||
targets = tokenizer(text_target=tgt_text, return_tensors="pt")
|
||||
input_ids = inputs["input_ids"]
|
||||
labels = targets["input_ids"]
|
||||
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
|
||||
|
||||
Reference in New Issue
Block a user