camembert: add implementation for save_vocabulary method
This commit is contained in:
@@ -16,9 +16,14 @@
|
|||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import (absolute_import, division, print_function,
|
||||||
unicode_literals)
|
unicode_literals)
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from shutil import copyfile
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'}
|
VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'}
|
||||||
|
|
||||||
@@ -55,6 +60,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
|
|||||||
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
|
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
|
||||||
self.sp_model = spm.SentencePieceProcessor()
|
self.sp_model = spm.SentencePieceProcessor()
|
||||||
self.sp_model.Load(str(vocab_file))
|
self.sp_model.Load(str(vocab_file))
|
||||||
|
self.vocab_file = vocab_file
|
||||||
# HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual
|
# HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual
|
||||||
# sentencepiece vocabulary (this is the case for <s> and </s>
|
# sentencepiece vocabulary (this is the case for <s> and </s>
|
||||||
self.fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3}
|
self.fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3}
|
||||||
@@ -135,3 +141,17 @@ class CamembertTokenizer(PreTrainedTokenizer):
|
|||||||
if index in self.fairseq_ids_to_tokens:
|
if index in self.fairseq_ids_to_tokens:
|
||||||
return self.fairseq_ids_to_tokens[index]
|
return self.fairseq_ids_to_tokens[index]
|
||||||
return self.sp_model.IdToPiece(index - self.fairseq_offset)
|
return self.sp_model.IdToPiece(index - self.fairseq_offset)
|
||||||
|
|
||||||
|
def save_vocabulary(self, save_directory):
|
||||||
|
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
|
||||||
|
to a directory.
|
||||||
|
"""
|
||||||
|
if not os.path.isdir(save_directory):
|
||||||
|
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||||
|
return
|
||||||
|
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
|
|
||||||
|
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||||
|
copyfile(self.vocab_file, out_vocab_file)
|
||||||
|
|
||||||
|
return (out_vocab_file,)
|
||||||
|
|||||||
Reference in New Issue
Block a user