Replace as_target context managers by direct calls (#18325)
* Preliminary work on tokenizers * Quality + fix tests * Treat processors * Fix pad * Remove all uses of in tests, docs and examples * Replace all as_target_tokenizer * Fix tests * Fix quality * Update examples/flax/image-captioning/run_image_captioning_flax.py Co-authored-by: amyeroberts <amy@huggingface.co> * Style Co-authored-by: amyeroberts <amy@huggingface.co>
This commit is contained in:
@@ -438,10 +438,7 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
||||
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
|
||||
expected_ids = [38, 121, 14, 697, 38848, 0]
|
||||
|
||||
model_inputs = self.tokenizer(src, return_tensors="pt").to(torch_device)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(tgt, return_tensors="pt")
|
||||
model_inputs["labels"] = targets["input_ids"].to(torch_device)
|
||||
model_inputs = self.tokenizer(src, text_target=tgt, return_tensors="pt").to(torch_device)
|
||||
|
||||
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
|
||||
|
||||
|
||||
@@ -145,9 +145,8 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
src_ids = tokenizer(source_text).input_ids
|
||||
self.assertListEqual(src_ids, expected_src_ids)
|
||||
|
||||
with tokenizer.as_target_tokenizer():
|
||||
target_ids = tokenizer(target_text).input_ids
|
||||
self.assertListEqual(target_ids, expected_target_ids)
|
||||
target_ids = tokenizer(text_target=target_text).input_ids
|
||||
self.assertListEqual(target_ids, expected_target_ids)
|
||||
|
||||
decoded = tokenizer.decode(target_ids, skip_special_tokens=True)
|
||||
self.assertEqual(decoded, target_text)
|
||||
|
||||
Reference in New Issue
Block a user