[MarianTokenizer] implement save_vocabulary and other common methods (#4389)
This commit is contained in:
@@ -1,7 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, List, Optional, Union
|
from pathlib import Path
|
||||||
|
from shutil import copyfile
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import sentencepiece
|
import sentencepiece
|
||||||
|
|
||||||
@@ -15,7 +17,7 @@ vocab_files_names = {
|
|||||||
"vocab": "vocab.json",
|
"vocab": "vocab.json",
|
||||||
"tokenizer_config_file": "tokenizer_config.json",
|
"tokenizer_config_file": "tokenizer_config.json",
|
||||||
}
|
}
|
||||||
MODEL_NAMES = ("opus-mt-en-de",) # TODO(SS): the only required constant is vocab_files_names
|
MODEL_NAMES = ("opus-mt-en-de",) # TODO(SS): delete this, the only required constant is vocab_files_names
|
||||||
PRETRAINED_VOCAB_FILES_MAP = {
|
PRETRAINED_VOCAB_FILES_MAP = {
|
||||||
k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES}
|
k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES}
|
||||||
for k, fname in vocab_files_names.items()
|
for k, fname in vocab_files_names.items()
|
||||||
@@ -55,14 +57,16 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
eos_token="</s>",
|
eos_token="</s>",
|
||||||
pad_token="<pad>",
|
pad_token="<pad>",
|
||||||
max_len=512,
|
max_len=512,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
# bos_token=bos_token,
|
# bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id
|
||||||
max_len=max_len,
|
max_len=max_len,
|
||||||
eos_token=eos_token,
|
eos_token=eos_token,
|
||||||
unk_token=unk_token,
|
unk_token=unk_token,
|
||||||
pad_token=pad_token,
|
pad_token=pad_token,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.encoder = load_json(vocab)
|
self.encoder = load_json(vocab)
|
||||||
if self.unk_token not in self.encoder:
|
if self.unk_token not in self.encoder:
|
||||||
@@ -72,21 +76,23 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
self.source_lang = source_lang
|
self.source_lang = source_lang
|
||||||
self.target_lang = target_lang
|
self.target_lang = target_lang
|
||||||
|
self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
|
||||||
|
self.spm_files = [source_spm, target_spm]
|
||||||
|
|
||||||
# load SentencePiece model for pre-processing
|
# load SentencePiece model for pre-processing
|
||||||
self.spm_source = sentencepiece.SentencePieceProcessor()
|
self.spm_source = load_spm(source_spm)
|
||||||
self.spm_source.Load(source_spm)
|
self.spm_target = load_spm(target_spm)
|
||||||
|
self.current_spm = self.spm_source
|
||||||
self.spm_target = sentencepiece.SentencePieceProcessor()
|
|
||||||
self.spm_target.Load(target_spm)
|
|
||||||
|
|
||||||
# Multilingual target side: default to using first supported language code.
|
# Multilingual target side: default to using first supported language code.
|
||||||
self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
|
|
||||||
|
|
||||||
|
self._setup_normalizer()
|
||||||
|
|
||||||
|
def _setup_normalizer(self):
|
||||||
try:
|
try:
|
||||||
from mosestokenizer import MosesPunctuationNormalizer
|
from mosestokenizer import MosesPunctuationNormalizer
|
||||||
|
|
||||||
self.punc_normalizer = MosesPunctuationNormalizer(source_lang)
|
self.punc_normalizer = MosesPunctuationNormalizer(self.source_lang)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
warnings.warn("Recommended: pip install mosestokenizer")
|
warnings.warn("Recommended: pip install mosestokenizer")
|
||||||
self.punc_normalizer = lambda x: x
|
self.punc_normalizer = lambda x: x
|
||||||
@@ -176,6 +182,65 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
def vocab_size(self) -> int:
|
def vocab_size(self) -> int:
|
||||||
return len(self.encoder)
|
return len(self.encoder)
|
||||||
|
|
||||||
|
def save_vocabulary(self, save_directory: str) -> Tuple[str]:
    """Save the vocab as JSON and copy both SentencePiece models into ``save_directory``.

    Files are written under the canonical names listed in ``self.vocab_files_names``
    (not under the basename of wherever the SentencePiece models were loaded from),
    so the directory uses the same file names the class declares for loading.

    Args:
        save_directory: path of an existing directory to write into.

    Returns:
        One string path per entry of ``self.vocab_files_names``, each inside
        ``save_directory``. This includes the ``tokenizer_config_file`` entry,
        which this method itself does not write.

    Raises:
        AssertionError: if ``save_directory`` is not an existing directory.
    """
    save_dir = Path(save_directory)
    assert save_dir.is_dir(), f"{save_directory} should be a directory"
    save_json(self.encoder, save_dir / self.vocab_files_names["vocab"])

    # self.spm_files is stored as [source_spm, target_spm]; pair each with its
    # canonical file-name key so the copies do not keep hashed/cache basenames.
    for key, src in zip(("source_spm", "target_spm"), self.spm_files):
        dest_path = save_dir / self.vocab_files_names[key]
        # Skip when the file is already there (e.g. saving back into the
        # directory the tokenizer was loaded from) to avoid copying onto itself.
        if not dest_path.exists():
            copyfile(src, dest_path)

    # Iterate the file NAMES (dict values); iterating the dict itself would
    # yield the keys ("vocab", "source_spm", ...) and produce non-existent paths.
    return tuple(str(save_dir / fname) for fname in self.vocab_files_names.values())
|
||||||
|
|
||||||
|
def get_vocab(self) -> Dict:
    """Return a new dict combining the base vocab with any added tokens.

    Entries from ``self.added_tokens_encoder`` take precedence over entries
    of ``self.encoder`` that share a key. ``self.encoder`` is left untouched.
    """
    return {**self.encoder, **self.added_tokens_encoder}
|
||||||
|
|
||||||
|
def __getstate__(self) -> Dict:
    """Return a shallow copy of the instance dict with runtime-only handles blanked.

    The SentencePiece processors and the punctuation normalizer are replaced
    by ``None`` in the returned state (presumably because they cannot be
    pickled); ``__setstate__`` rebuilds them.
    """
    runtime_only = ("spm_source", "spm_target", "current_spm", "punc_normalizer")
    state = dict(self.__dict__)
    for attr in runtime_only:
        state[attr] = None
    return state
|
||||||
|
|
||||||
|
def __setstate__(self, d: Dict) -> None:
    """Restore the instance from ``d`` and rebuild the handles dropped by ``__getstate__``.

    The SentencePiece processors are reloaded from the paths in
    ``self.spm_files`` (source first, then target), the source processor
    becomes the current one, and the punctuation normalizer is set up again.
    """
    self.__dict__ = d
    processors = [load_spm(spm_path) for spm_path in self.spm_files]
    self.spm_source, self.spm_target = processors
    self.current_spm = self.spm_source
    self._setup_normalizer()
|
||||||
|
|
||||||
|
def num_special_tokens_to_add(self, **unused):
    """Return how many special tokens are added to a sequence: always 1 (just EOS).

    Any keyword arguments are accepted and ignored.
    """
    eos_only = 1
    return eos_only
|
||||||
|
|
||||||
|
def _special_token_mask(self, seq):
    """Return a 0/1 list marking which ids in ``seq`` are special tokens.

    The unknown-token id is deliberately not counted as special here
    (<unk> is only sometimes special).
    """
    # Build the lookup set once rather than once per element of seq.
    special_ids = set(self.all_special_ids)
    special_ids.remove(self.unk_token_id)
    return [int(token_id in special_ids) for token_id in seq]
|
||||||
|
|
||||||
|
def get_special_tokens_mask(
    self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
) -> List[int]:
    """Return a list with 1 at special-token positions and 0 elsewhere.

    Args:
        token_ids_0: ids of the first sequence.
        token_ids_1: optional ids of a second sequence; it is concatenated
            after ``token_ids_0``. Ignored when ``already_has_special_tokens``
            is true.
        already_has_special_tokens: when true, only ``token_ids_0`` is masked
            and nothing is appended.

    Returns:
        The mask from ``self._special_token_mask``. Unless
        ``already_has_special_tokens`` is set, one trailing ``1`` is appended
        for the special token that will be added at the end.
    """
    if already_has_special_tokens:
        return self._special_token_mask(token_ids_0)

    sequence = token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
    return self._special_token_mask(sequence) + [1]
|
||||||
|
|
||||||
|
|
||||||
|
def load_spm(path: str) -> sentencepiece.SentencePieceProcessor:
    """Create a SentencePiece processor and load the model file at ``path`` into it."""
    processor = sentencepiece.SentencePieceProcessor()
    processor.Load(path)
    return processor
|
||||||
|
|
||||||
|
|
||||||
|
def save_json(data, path: str) -> None:
    """Write ``data`` to the file at ``path`` as JSON indented by two spaces.

    The file is opened in text write mode, so existing content is replaced.
    """
    with open(path, "w") as json_file:
        json.dump(data, json_file, indent=2)
|
||||||
|
|
||||||
|
|
||||||
def load_json(path: str) -> Union[Dict, List]:
|
def load_json(path: str) -> Union[Dict, List]:
|
||||||
with open(path, "r") as f:
|
with open(path, "r") as f:
|
||||||
|
|||||||
@@ -129,11 +129,6 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
|||||||
max_indices = logits.argmax(-1)
|
max_indices = logits.argmax(-1)
|
||||||
self.tokenizer.batch_decode(max_indices)
|
self.tokenizer.batch_decode(max_indices)
|
||||||
|
|
||||||
def test_tokenizer_equivalence(self):
|
|
||||||
batch = self.tokenizer.prepare_translation_batch(["I am a small frog"]).to(torch_device)
|
|
||||||
expected = [38, 121, 14, 697, 38848, 0]
|
|
||||||
self.assertListEqual(expected, batch.input_ids[0].tolist())
|
|
||||||
|
|
||||||
def test_unk_support(self):
|
def test_unk_support(self):
|
||||||
t = self.tokenizer
|
t = self.tokenizer
|
||||||
ids = t.prepare_translation_batch(["||"]).to(torch_device).input_ids[0].tolist()
|
ids = t.prepare_translation_batch(["||"]).to(torch_device).input_ids[0].tolist()
|
||||||
|
|||||||
70
tests/test_tokenization_marian.py
Normal file
70
tests/test_tokenization_marian.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020 Huggingface
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
from shutil import copyfile
|
||||||
|
|
||||||
|
from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_files_names
|
||||||
|
from transformers.tokenization_utils import BatchEncoding
|
||||||
|
|
||||||
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
|
from .utils import slow
|
||||||
|
|
||||||
|
|
||||||
|
# Absolute path of the SentencePiece model fixture stored next to this test module.
SAMPLE_SP = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "fixtures/test_sentencepiece.model",
)

# Minimal tokenizer config written into the temporary test directory.
mock_tokenizer_config = {
    "target_lang": "fi",
    "source_lang": "en",
}
# Example of a ">>xx<<" style language-code token.
zh_code = ">>zh<<"
# Organisation prefix for pretrained model identifiers.
ORG_NAME = "Helsinki-NLP/"
|
||||||
|
|
||||||
|
|
||||||
|
class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    """Run the shared tokenizer tests against a small MarianTokenizer built from fixtures."""

    tokenizer_class = MarianTokenizer

    def setUp(self):
        """Populate the temporary directory with a tiny vocab, config and SentencePiece models."""
        super().setUp()
        vocab = ["</s>", "<unk>", "▁This", "▁is", "▁a", "▁t", "est", "\u0120", "<pad>"]
        vocab_tokens = {token: index for index, token in enumerate(vocab)}
        work_dir = Path(self.tmpdirname)
        save_json(vocab_tokens, work_dir / vocab_files_names["vocab"])
        save_json(mock_tokenizer_config, work_dir / vocab_files_names["tokenizer_config_file"])

        # NOTE(review): the original indentation was lost in this view; both
        # copies are assumed to sit under the existence check -- confirm.
        source_model = work_dir / vocab_files_names["source_spm"]
        if not source_model.exists():
            copyfile(SAMPLE_SP, source_model)
            copyfile(SAMPLE_SP, work_dir / vocab_files_names["target_spm"])

        # Round-trip once so the directory also holds whatever save_pretrained writes.
        tokenizer = MarianTokenizer.from_pretrained(self.tmpdirname)
        tokenizer.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, max_len=None, **kwargs) -> MarianTokenizer:
        """Load the fixture tokenizer, passing ``max_len`` explicitly to override the 512 default."""
        return MarianTokenizer.from_pretrained(self.tmpdirname, max_len=max_len, **kwargs)

    def get_input_output_texts(self):
        """Return the (input, expected output) text pair used by the shared tests."""
        text = "This is a test"
        return (text, text)

    @slow
    def test_tokenizer_equivalence_en_de(self):
        """The pretrained en-de tokenizer yields the expected ids for a fixed sentence."""
        tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
        encoded = tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None)
        self.assertIsInstance(encoded, BatchEncoding)
        expected_ids = [38, 121, 14, 697, 38848, 0]
        self.assertListEqual(expected_ids, encoded.input_ids[0])
|
||||||
Reference in New Issue
Block a user