camembert: add implementation for save_vocabulary method
This commit is contained in:
@@ -16,9 +16,14 @@
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import logging
|
||||
import os
|
||||
from shutil import copyfile
|
||||
|
||||
import sentencepiece as spm
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'}
|
||||
|
||||
@@ -55,6 +60,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
|
||||
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
|
||||
self.sp_model = spm.SentencePieceProcessor()
|
||||
self.sp_model.Load(str(vocab_file))
|
||||
self.vocab_file = vocab_file
|
||||
# HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual
|
||||
# sentencepiece vocabulary (this is the case for <s> and </s>
|
||||
self.fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3}
|
||||
@@ -135,3 +141,17 @@ class CamembertTokenizer(PreTrainedTokenizer):
|
||||
if index in self.fairseq_ids_to_tokens:
|
||||
return self.fairseq_ids_to_tokens[index]
|
||||
return self.sp_model.IdToPiece(index - self.fairseq_offset)
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
|
||||
to a directory.
|
||||
"""
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
return
|
||||
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
Reference in New Issue
Block a user