Replace as_target context managers by direct calls (#18325)
* Preliminary work on tokenizers * Quality + fix tests * Treat processors * Fix pad * Remove all uses of in tests, docs and examples * Replace all as_target_tokenizer * Fix tests * Fix quality * Update examples/flax/image-captioning/run_image_captioning_flax.py Co-authored-by: amyeroberts <amy@huggingface.co> * Style Co-authored-by: amyeroberts <amy@huggingface.co>
This commit is contained in:
@@ -187,9 +187,7 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
|
||||
self.tokenizer.src_lang = "en"
|
||||
self.tokenizer.tgt_lang = "fr"
|
||||
|
||||
batch = self.tokenizer(self.src_text, padding=True, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
batch["labels"] = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt").input_ids
|
||||
batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
|
||||
|
||||
batch["decoder_input_ids"] = shift_tokens_right(
|
||||
batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.eos_token_id
|
||||
@@ -217,17 +215,19 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
|
||||
@require_torch
|
||||
def test_as_target_tokenizer(self):
|
||||
def test_tokenizer_target_mode(self):
|
||||
self.tokenizer.tgt_lang = "mr"
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("mr")])
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
self.tokenizer._switch_to_target_mode()
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("mr")])
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
self.tokenizer._switch_to_input_mode()
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
|
||||
|
||||
self.tokenizer.tgt_lang = "zh"
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
self.tokenizer._switch_to_target_mode()
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
self.tokenizer._switch_to_input_mode()
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
|
||||
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user