From 0b3d45eb64607158977f546d57f90eae268c7836 Mon Sep 17 00:00:00 2001 From: Stefan Schweter Date: Mon, 18 Nov 2019 15:49:44 +0100 Subject: [PATCH] camembert: add implementation for save_vocabulary method --- transformers/tokenization_camembert.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/transformers/tokenization_camembert.py b/transformers/tokenization_camembert.py index 41d3d74cff..bf2a6fe993 100644 --- a/transformers/tokenization_camembert.py +++ b/transformers/tokenization_camembert.py @@ -16,9 +16,14 @@ from __future__ import (absolute_import, division, print_function, unicode_literals) +import logging +import os +from shutil import copyfile + import sentencepiece as spm from transformers.tokenization_utils import PreTrainedTokenizer +logger = logging.getLogger(__name__) VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'} @@ -55,6 +60,7 @@ class CamembertTokenizer(PreTrainedTokenizer): self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens self.sp_model = spm.SentencePieceProcessor() self.sp_model.Load(str(vocab_file)) + self.vocab_file = vocab_file # HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual # sentencepiece vocabulary (this is the case for and self.fairseq_tokens_to_ids = {'NOTUSED': 0, '': 1, 'NOTUSED': 2, '': 3} @@ -135,3 +141,17 @@ class CamembertTokenizer(PreTrainedTokenizer): if index in self.fairseq_ids_to_tokens: return self.fairseq_ids_to_tokens[index] return self.sp_model.IdToPiece(index - self.fairseq_offset) + + def save_vocabulary(self, save_directory): + """ Save the sentencepiece vocabulary (copy original file) and special tokens file + to a directory. + """ + if not os.path.isdir(save_directory): + logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) + return + out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,)