Fix marian tokenizer save pretrained (#5043)
This commit is contained in:
@@ -40,9 +40,9 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab=None,
|
vocab,
|
||||||
source_spm=None,
|
source_spm,
|
||||||
target_spm=None,
|
target_spm,
|
||||||
source_lang=None,
|
source_lang=None,
|
||||||
target_lang=None,
|
target_lang=None,
|
||||||
unk_token="<unk>",
|
unk_token="<unk>",
|
||||||
@@ -59,6 +59,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
pad_token=pad_token,
|
pad_token=pad_token,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
assert Path(source_spm).exists(), f"cannot find spm source {source_spm}"
|
||||||
self.encoder = load_json(vocab)
|
self.encoder = load_json(vocab)
|
||||||
if self.unk_token not in self.encoder:
|
if self.unk_token not in self.encoder:
|
||||||
raise KeyError("<unk> token must be in vocab")
|
raise KeyError("<unk> token must be in vocab")
|
||||||
@@ -179,10 +180,11 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
assert save_dir.is_dir(), f"{save_directory} should be a directory"
|
assert save_dir.is_dir(), f"{save_directory} should be a directory"
|
||||||
save_json(self.encoder, save_dir / self.vocab_files_names["vocab"])
|
save_json(self.encoder, save_dir / self.vocab_files_names["vocab"])
|
||||||
|
|
||||||
for f in self.spm_files:
|
for orig, f in zip(["source.spm", "target.spm"], self.spm_files):
|
||||||
dest_path = save_dir / Path(f).name
|
dest_path = save_dir / Path(f).name
|
||||||
if not dest_path.exists():
|
if not dest_path.exists():
|
||||||
copyfile(f, save_dir / Path(f).name)
|
copyfile(f, save_dir / orig)
|
||||||
|
|
||||||
return tuple(save_dir / f for f in self.vocab_files_names)
|
return tuple(save_dir / f for f in self.vocab_files_names)
|
||||||
|
|
||||||
def get_vocab(self) -> Dict:
|
def get_vocab(self) -> Dict:
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@@ -23,7 +24,6 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f
|
|||||||
from transformers.tokenization_utils import BatchEncoding
|
from transformers.tokenization_utils import BatchEncoding
|
||||||
|
|
||||||
from .test_tokenization_common import TokenizerTesterMixin
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
from .utils import slow
|
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
|
SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
|
||||||
@@ -60,10 +60,15 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
"This is a test",
|
"This is a test",
|
||||||
)
|
)
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_tokenizer_equivalence_en_de(self):
|
def test_tokenizer_equivalence_en_de(self):
|
||||||
en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
|
en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
|
||||||
batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None)
|
batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None)
|
||||||
self.assertIsInstance(batch, BatchEncoding)
|
self.assertIsInstance(batch, BatchEncoding)
|
||||||
expected = [38, 121, 14, 697, 38848, 0]
|
expected = [38, 121, 14, 697, 38848, 0]
|
||||||
self.assertListEqual(expected, batch.input_ids[0])
|
self.assertListEqual(expected, batch.input_ids[0])
|
||||||
|
|
||||||
|
save_dir = tempfile.mkdtemp()
|
||||||
|
en_de_tokenizer.save_pretrained(save_dir)
|
||||||
|
contents = [x.name for x in Path(save_dir).glob("*")]
|
||||||
|
self.assertIn("source.spm", contents)
|
||||||
|
MarianTokenizer.from_pretrained(save_dir)
|
||||||
|
|||||||
Reference in New Issue
Block a user