From e02ed0ee7e1b500010452b569087f4e6ddd1f800 Mon Sep 17 00:00:00 2001 From: Benjamin Davidson Date: Thu, 16 Sep 2021 21:30:05 +0100 Subject: [PATCH] XLMR tokenizer is fully picklable (#13577) * made tokenizer fully picklable * remove whitespace * added testcase --- .../models/xlm_roberta/tokenization_xlm_roberta.py | 3 ++- tests/test_tokenization_xlm_roberta.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py index 32f0cfaed8..78a56615eb 100644 --- a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py @@ -171,6 +171,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer): def __getstate__(self): state = self.__dict__.copy() state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() return state def __setstate__(self, d): @@ -181,7 +182,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer): self.sp_model_kwargs = {} self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.Load(self.vocab_file) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None diff --git a/tests/test_tokenization_xlm_roberta.py b/tests/test_tokenization_xlm_roberta.py index 3604395e6f..d25782ae19 100644 --- a/tests/test_tokenization_xlm_roberta.py +++ b/tests/test_tokenization_xlm_roberta.py @@ -14,6 +14,9 @@ # limitations under the License. import os +import pickle +import shutil +import tempfile import unittest from transformers import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast @@ -141,6 +144,13 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): def big_tokenizer(self): return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base") + def test_picklable_without_disk(self): + with tempfile.NamedTemporaryFile() as f: + shutil.copyfile(SAMPLE_VOCAB, f.name) + tokenizer = XLMRobertaTokenizer(f.name, keep_accents=True) + pickled_tokenizer = pickle.dumps(tokenizer) + pickle.loads(pickled_tokenizer) + def test_rust_and_python_full_tokenizers(self): if not self.test_rust_tokenizer: return