From fb3b22c3b932cfbaffa217f1d036ab3ce668add2 Mon Sep 17 00:00:00 2001 From: Yuchao Dai <3407450+icyblade@users.noreply.github.com> Date: Thu, 6 Jul 2023 17:21:27 +0800 Subject: [PATCH] LlamaTokenizer should be picklable (#24681) * LlamaTokenizer should be picklable * make fixup --- src/transformers/models/llama/tokenization_llama.py | 3 ++- tests/models/llama/test_tokenization_llama.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/tokenization_llama.py b/src/transformers/models/llama/tokenization_llama.py index 13f093ae94..50b7c7a262 100644 --- a/src/transformers/models/llama/tokenization_llama.py +++ b/src/transformers/models/llama/tokenization_llama.py @@ -98,12 +98,13 @@ class LlamaTokenizer(PreTrainedTokenizer): def __getstate__(self): state = self.__dict__.copy() state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() return state def __setstate__(self, d): self.__dict__ = d self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.Load(self.vocab_file) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) @property def vocab_size(self): diff --git a/tests/models/llama/test_tokenization_llama.py b/tests/models/llama/test_tokenization_llama.py index 3a1ec2be93..67d287fac1 100644 --- a/tests/models/llama/test_tokenization_llama.py +++ b/tests/models/llama/test_tokenization_llama.py @@ -14,6 +14,7 @@ # limitations under the License. import os +import pickle import shutil import tempfile import unittest @@ -285,6 +286,13 @@ class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): padding=False, ) + def test_picklable(self): + with tempfile.NamedTemporaryFile() as f: + shutil.copyfile(SAMPLE_VOCAB, f.name) + tokenizer = LlamaTokenizer(f.name, keep_accents=True) + pickled_tokenizer = pickle.dumps(tokenizer) + pickle.loads(pickled_tokenizer) + @require_torch @require_sentencepiece