Fix marian tokenizer save pretrained (#5043)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
@@ -23,7 +24,6 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f
|
||||
from transformers.tokenization_utils import BatchEncoding
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
from .utils import slow
|
||||
|
||||
|
||||
SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
|
||||
@@ -60,10 +60,15 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
"This is a test",
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_tokenizer_equivalence_en_de(self):
|
||||
en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
|
||||
batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
expected = [38, 121, 14, 697, 38848, 0]
|
||||
self.assertListEqual(expected, batch.input_ids[0])
|
||||
|
||||
save_dir = tempfile.mkdtemp()
|
||||
en_de_tokenizer.save_pretrained(save_dir)
|
||||
contents = [x.name for x in Path(save_dir).glob("*")]
|
||||
self.assertIn("source.spm", contents)
|
||||
MarianTokenizer.from_pretrained(save_dir)
|
||||
|
||||
Reference in New Issue
Block a user