[MarianTokenizer] implement save_vocabulary and other common methods (#4389)

This commit is contained in:
Sam Shleifer
2020-05-19 19:45:49 -04:00
committed by GitHub
parent 956c4c4eb4
commit efbc1c5a9d
3 changed files with 145 additions and 15 deletions

View File

@@ -1,7 +1,9 @@
import json
import re
import warnings
from typing import Dict, List, Optional, Union
from pathlib import Path
from shutil import copyfile
from typing import Dict, List, Optional, Tuple, Union
import sentencepiece
@@ -15,7 +17,7 @@ vocab_files_names = {
"vocab": "vocab.json",
"tokenizer_config_file": "tokenizer_config.json",
}
MODEL_NAMES = ("opus-mt-en-de",) # TODO(SS): the only required constant is vocab_files_names
MODEL_NAMES = ("opus-mt-en-de",) # TODO(SS): delete this, the only required constant is vocab_files_names
PRETRAINED_VOCAB_FILES_MAP = {
k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES}
for k, fname in vocab_files_names.items()
@@ -55,14 +57,16 @@ class MarianTokenizer(PreTrainedTokenizer):
eos_token="</s>",
pad_token="<pad>",
max_len=512,
**kwargs,
):
super().__init__(
# bos_token=bos_token,
# bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id
max_len=max_len,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
**kwargs,
)
self.encoder = load_json(vocab)
if self.unk_token not in self.encoder:
@@ -72,21 +76,23 @@ class MarianTokenizer(PreTrainedTokenizer):
self.source_lang = source_lang
self.target_lang = target_lang
self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
self.spm_files = [source_spm, target_spm]
# load SentencePiece model for pre-processing
self.spm_source = sentencepiece.SentencePieceProcessor()
self.spm_source.Load(source_spm)
self.spm_target = sentencepiece.SentencePieceProcessor()
self.spm_target.Load(target_spm)
self.spm_source = load_spm(source_spm)
self.spm_target = load_spm(target_spm)
self.current_spm = self.spm_source
# Multilingual target side: default to using first supported language code.
self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
self._setup_normalizer()
def _setup_normalizer(self):
    """Attach a punctuation normalizer to ``self.punc_normalizer``.

    Tries to build a ``MosesPunctuationNormalizer`` for ``self.source_lang``.
    If the optional ``mosestokenizer`` package cannot be imported, a warning
    is emitted and an identity function is installed instead.
    """
    try:
        from mosestokenizer import MosesPunctuationNormalizer

        # Read the language from the instance: this method has no local
        # ``source_lang``, so the bare name would raise NameError here.
        self.punc_normalizer = MosesPunctuationNormalizer(self.source_lang)
    except ImportError:
        warnings.warn("Recommended: pip install mosestokenizer")
        self.punc_normalizer = lambda x: x
@@ -176,6 +182,65 @@ class MarianTokenizer(PreTrainedTokenizer):
def vocab_size(self) -> int:
    """Return how many entries the ``self.encoder`` mapping holds."""
    n_entries = len(self.encoder)
    return n_entries
def save_vocabulary(self, save_directory: str) -> Tuple[str]:
    """Write the vocab JSON and copy the SentencePiece files into a directory.

    Args:
        save_directory: Path of an existing directory to write into.

    Returns:
        A tuple of path strings: the vocab JSON file first, followed by the
        destination path of each file listed in ``self.spm_files``.

    Raises:
        AssertionError: If ``save_directory`` is not an existing directory.
    """
    save_dir = Path(save_directory)
    assert save_dir.is_dir(), f"{save_directory} should be a directory"

    vocab_path = save_dir / self.vocab_files_names["vocab"]
    save_json(self.encoder, vocab_path)
    saved_paths = [vocab_path]

    for spm_file in self.spm_files:
        dest_path = save_dir / Path(spm_file).name
        # Leave a file that is already present at the destination untouched.
        if not dest_path.exists():
            copyfile(spm_file, dest_path)
        saved_paths.append(dest_path)

    # Report the files handled above. Iterating ``self.vocab_files_names``
    # directly would yield its keys (e.g. "vocab"), not real file names.
    return tuple(str(path) for path in saved_paths)
def get_vocab(self) -> Dict:
    """Return a new dict combining ``self.encoder`` and ``self.added_tokens_encoder``.

    Entries from ``added_tokens_encoder`` take precedence on key collisions;
    neither source mapping is modified.
    """
    return {**self.encoder, **self.added_tokens_encoder}
def __getstate__(self) -> Dict:
    """Return a copy of the instance ``__dict__`` suitable for pickling.

    The SentencePiece processors, the current-processor reference and the
    punctuation normalizer are replaced by ``None`` in the returned copy;
    the instance itself is left unchanged.
    """
    state = dict(self.__dict__)
    for attr in ("spm_source", "spm_target", "current_spm", "punc_normalizer"):
        state[attr] = None
    return state
def __setstate__(self, d: Dict) -> None:
    """Restore the instance from ``d`` and rebuild its non-pickled members.

    After adopting ``d`` as the instance ``__dict__``, the two SentencePiece
    processors are reloaded from ``self.spm_files``, the source processor is
    made current, and the punctuation normalizer is set up again.
    """
    self.__dict__ = d
    processors = [load_spm(spm_file) for spm_file in self.spm_files]
    self.spm_source, self.spm_target = processors
    self.current_spm = self.spm_source
    self._setup_normalizer()
def num_special_tokens_to_add(self, **unused):
    """Return the number of special tokens added to a sequence: one (just EOS).

    Any keyword arguments are accepted and ignored.
    """
    eos_only = 1
    return eos_only
def _special_token_mask(self, seq):
    """Return a list of 0/1 flags marking which ids in ``seq`` are special.

    An id is flagged 1 when it appears in ``self.all_special_ids``, except
    for ``self.unk_token_id``, which is always flagged 0.
    """
    # Build the lookup set once instead of inside the list comprehension.
    special_ids = set(self.all_special_ids)
    # <unk> is only sometimes special, so it is never flagged. ``discard``
    # avoids a KeyError when the unk id is not among the special ids.
    special_ids.discard(self.unk_token_id)
    return [1 if token_id in special_ids else 0 for token_id in seq]
def get_special_tokens_mask(
    self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
) -> List[int]:
    """Return a 0/1 mask over the given token ids, 1 marking special tokens.

    Args:
        token_ids_0: Ids of the first sequence.
        token_ids_1: Optional ids of a second sequence, appended to the first.
        already_has_special_tokens: When true, only ``token_ids_0`` is masked
            and nothing is appended.

    Returns:
        The mask from ``self._special_token_mask``; unless
        ``already_has_special_tokens`` is set, a trailing ``1`` is appended
        (presumably for the EOS token added afterwards).
    """
    if already_has_special_tokens:
        return self._special_token_mask(token_ids_0)
    ids = token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
    return self._special_token_mask(ids) + [1]
def load_spm(path: str) -> sentencepiece.SentencePieceProcessor:
    """Create a ``SentencePieceProcessor`` and load the model file at ``path`` into it."""
    processor = sentencepiece.SentencePieceProcessor()
    processor.Load(path)
    return processor
def save_json(data, path: str) -> None:
    """Serialize ``data`` as JSON with two-space indentation and write it to ``path``."""
    with open(path, "w") as out_file:
        json.dump(obj=data, fp=out_file, indent=2)
def load_json(path: str) -> Union[Dict, List]:
with open(path, "r") as f: