From 3d495c61efbd2ca8a17827ff3103f7c820f0e9da Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 16 Jun 2020 09:48:19 -0400 Subject: [PATCH] Fix marian tokenizer save pretrained (#5043) --- src/transformers/tokenization_marian.py | 12 +++++++----- tests/test_tokenization_marian.py | 9 +++++++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/transformers/tokenization_marian.py b/src/transformers/tokenization_marian.py index 57307615e5..8a95d1fbbd 100644 --- a/src/transformers/tokenization_marian.py +++ b/src/transformers/tokenization_marian.py @@ -40,9 +40,9 @@ class MarianTokenizer(PreTrainedTokenizer): def __init__( self, - vocab=None, - source_spm=None, - target_spm=None, + vocab, + source_spm, + target_spm, source_lang=None, target_lang=None, unk_token="", @@ -59,6 +59,7 @@ class MarianTokenizer(PreTrainedTokenizer): pad_token=pad_token, **kwargs, ) + assert Path(source_spm).exists(), f"cannot find spm source {source_spm}" self.encoder = load_json(vocab) if self.unk_token not in self.encoder: raise KeyError(" token must be in vocab") @@ -179,10 +180,11 @@ class MarianTokenizer(PreTrainedTokenizer): assert save_dir.is_dir(), f"{save_directory} should be a directory" save_json(self.encoder, save_dir / self.vocab_files_names["vocab"]) - for f in self.spm_files: + for orig, f in zip(["source.spm", "target.spm"], self.spm_files): dest_path = save_dir / Path(f).name if not dest_path.exists(): - copyfile(f, save_dir / Path(f).name) + copyfile(f, save_dir / orig) + return tuple(save_dir / f for f in self.vocab_files_names) def get_vocab(self) -> Dict: diff --git a/tests/test_tokenization_marian.py b/tests/test_tokenization_marian.py index eea77e2b5d..91b4438a71 100644 --- a/tests/test_tokenization_marian.py +++ b/tests/test_tokenization_marian.py @@ -15,6 +15,7 @@ import os +import tempfile import unittest from pathlib import Path from shutil import copyfile @@ -23,7 +24,6 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f from transformers.tokenization_utils import BatchEncoding from .test_tokenization_common import TokenizerTesterMixin -from .utils import slow SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") @@ -60,10 +60,15 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase): "This is a test", ) - @slow def test_tokenizer_equivalence_en_de(self): en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de") batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None) self.assertIsInstance(batch, BatchEncoding) expected = [38, 121, 14, 697, 38848, 0] self.assertListEqual(expected, batch.input_ids[0]) + + save_dir = tempfile.mkdtemp() + en_de_tokenizer.save_pretrained(save_dir) + contents = [x.name for x in Path(save_dir).glob("*")] + self.assertIn("source.spm", contents) + MarianTokenizer.from_pretrained(save_dir)