From efbc1c5a9d96048ab11f8d746fe51107cb91646f Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 19 May 2020 19:45:49 -0400 Subject: [PATCH] [MarianTokenizer] implement save_vocabulary and other common methods (#4389) --- src/transformers/tokenization_marian.py | 85 ++++++++++++++++++++++--- tests/test_modeling_marian.py | 5 -- tests/test_tokenization_marian.py | 70 ++++++++++++++++++++ 3 files changed, 145 insertions(+), 15 deletions(-) create mode 100644 tests/test_tokenization_marian.py diff --git a/src/transformers/tokenization_marian.py b/src/transformers/tokenization_marian.py index cb2dab5248..4203fec09d 100644 --- a/src/transformers/tokenization_marian.py +++ b/src/transformers/tokenization_marian.py @@ -1,7 +1,9 @@ import json import re import warnings -from typing import Dict, List, Optional, Union +from pathlib import Path +from shutil import copyfile +from typing import Dict, List, Optional, Tuple, Union import sentencepiece @@ -15,7 +17,7 @@ vocab_files_names = { "vocab": "vocab.json", "tokenizer_config_file": "tokenizer_config.json", } -MODEL_NAMES = ("opus-mt-en-de",) # TODO(SS): the only required constant is vocab_files_names +MODEL_NAMES = ("opus-mt-en-de",) # TODO(SS): delete this, the only required constant is vocab_files_names PRETRAINED_VOCAB_FILES_MAP = { k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES} for k, fname in vocab_files_names.items() @@ -55,14 +57,16 @@ class MarianTokenizer(PreTrainedTokenizer): eos_token="", pad_token="", max_len=512, + **kwargs, ): super().__init__( - # bos_token=bos_token, + # bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id max_len=max_len, eos_token=eos_token, unk_token=unk_token, pad_token=pad_token, + **kwargs, ) self.encoder = load_json(vocab) if self.unk_token not in self.encoder: @@ -72,21 +76,23 @@ class MarianTokenizer(PreTrainedTokenizer): self.source_lang = source_lang self.target_lang = target_lang + self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")] + self.spm_files = [source_spm, target_spm] # load SentencePiece model for pre-processing - self.spm_source = sentencepiece.SentencePieceProcessor() - self.spm_source.Load(source_spm) - - self.spm_target = sentencepiece.SentencePieceProcessor() - self.spm_target.Load(target_spm) + self.spm_source = load_spm(source_spm) + self.spm_target = load_spm(target_spm) + self.current_spm = self.spm_source # Multilingual target side: default to using first supported language code. - self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")] + self._setup_normalizer() + + def _setup_normalizer(self): try: from mosestokenizer import MosesPunctuationNormalizer - self.punc_normalizer = MosesPunctuationNormalizer(source_lang) + self.punc_normalizer = MosesPunctuationNormalizer(self.source_lang) except ImportError: warnings.warn("Recommended: pip install mosestokenizer") self.punc_normalizer = lambda x: x @@ -176,6 +182,65 @@ class MarianTokenizer(PreTrainedTokenizer): def vocab_size(self) -> int: return len(self.encoder) + def save_vocabulary(self, save_directory: str) -> Tuple[str]: + """save vocab file to json and copy spm files from their original path.""" + save_dir = Path(save_directory) + assert save_dir.is_dir(), f"{save_directory} should be a directory" + save_json(self.encoder, save_dir / self.vocab_files_names["vocab"]) + + for f in self.spm_files: + dest_path = save_dir / Path(f).name + if not dest_path.exists(): + copyfile(f, save_dir / Path(f).name) + return tuple(save_dir / f for f in self.vocab_files_names) + + def get_vocab(self) -> Dict: + vocab = self.encoder.copy() + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self) -> Dict: + state = self.__dict__.copy() + state.update({k: None for k in ["spm_source", "spm_target", "current_spm", "punc_normalizer"]}) + return state + + def __setstate__(self, d: Dict) -> None: + self.__dict__ = d + self.spm_source, self.spm_target = (load_spm(f) for f in self.spm_files) + self.current_spm = self.spm_source + self._setup_normalizer() + + def num_special_tokens_to_add(self, **unused): + """Just EOS""" + return 1 + + def _special_token_mask(self, seq): + all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp + all_special_ids.remove(self.unk_token_id) # is only sometimes special + return [1 if x in all_special_ids else 0 for x in seq] + + def get_special_tokens_mask( + self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """Get list where entries are [1] if a token is [eos] or [pad] else 0.""" + if already_has_special_tokens: + return self._special_token_mask(token_ids_0) + elif token_ids_1 is None: + return self._special_token_mask(token_ids_0) + [1] + else: + return self._special_token_mask(token_ids_0 + token_ids_1) + [1] + + +def load_spm(path: str) -> sentencepiece.SentencePieceProcessor: + spm = sentencepiece.SentencePieceProcessor() + spm.Load(path) + return spm + + +def save_json(data, path: str) -> None: + with open(path, "w") as f: + json.dump(data, f, indent=2) + def load_json(path: str) -> Union[Dict, List]: with open(path, "r") as f: diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index 3858d273ab..c1a1f4a96c 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -129,11 +129,6 @@ class TestMarian_EN_DE_More(MarianIntegrationTest): max_indices = logits.argmax(-1) self.tokenizer.batch_decode(max_indices) - def test_tokenizer_equivalence(self): - batch = self.tokenizer.prepare_translation_batch(["I am a small frog"]).to(torch_device) - expected = [38, 121, 14, 697, 38848, 0] - self.assertListEqual(expected, batch.input_ids[0].tolist()) - def test_unk_support(self): t = self.tokenizer ids = t.prepare_translation_batch(["||"]).to(torch_device).input_ids[0].tolist() diff --git a/tests/test_tokenization_marian.py b/tests/test_tokenization_marian.py new file mode 100644 index 0000000000..688413af82 --- /dev/null +++ b/tests/test_tokenization_marian.py @@ -0,0 +1,70 @@ +# coding=utf-8 +# Copyright 2020 Huggingface +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import unittest +from pathlib import Path +from shutil import copyfile + +from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_files_names +from transformers.tokenization_utils import BatchEncoding + +from .test_tokenization_common import TokenizerTesterMixin +from .utils import slow + + +SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") + +mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"} +zh_code = ">>zh<<" +ORG_NAME = "Helsinki-NLP/" + + +class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = MarianTokenizer + + def setUp(self): + super().setUp() + vocab = ["", "", "▁This", "▁is", "▁a", "▁t", "est", "\u0120", ""] + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + save_dir = Path(self.tmpdirname) + save_json(vocab_tokens, save_dir / vocab_files_names["vocab"]) + save_json(mock_tokenizer_config, save_dir / vocab_files_names["tokenizer_config_file"]) + if not (save_dir / vocab_files_names["source_spm"]).exists(): + copyfile(SAMPLE_SP, save_dir / vocab_files_names["source_spm"]) + copyfile(SAMPLE_SP, save_dir / vocab_files_names["target_spm"]) + + tokenizer = MarianTokenizer.from_pretrained(self.tmpdirname) + tokenizer.save_pretrained(self.tmpdirname) + + def get_tokenizer(self, max_len=None, **kwargs) -> MarianTokenizer: + # overwrite max_len=512 default + return MarianTokenizer.from_pretrained(self.tmpdirname, max_len=max_len, **kwargs) + + def get_input_output_texts(self): + return ( + "This is a test", + "This is a test", + ) + + @slow + def test_tokenizer_equivalence_en_de(self): + en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de") + batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None) + self.assertIsInstance(batch, BatchEncoding) + expected = [38, 121, 14, 697, 38848, 0] + self.assertListEqual(expected, batch.input_ids[0])