XLMR tokenizer is fully picklable (#13577)
* made tokenizer fully picklable
* removed whitespace
* added test case
This commit is contained in:
Committed by: GitHub
Parent: af5c6ae5ed
Commit: e02ed0ee7e
@@ -171,6 +171,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
|
|||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
state = self.__dict__.copy()
|
state = self.__dict__.copy()
|
||||||
state["sp_model"] = None
|
state["sp_model"] = None
|
||||||
|
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def __setstate__(self, d):
|
def __setstate__(self, d):
|
||||||
@@ -181,7 +182,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
|
|||||||
self.sp_model_kwargs = {}
|
self.sp_model_kwargs = {}
|
||||||
|
|
||||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||||
self.sp_model.Load(self.vocab_file)
|
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
|
||||||
|
|
||||||
def build_inputs_with_special_tokens(
|
def build_inputs_with_special_tokens(
|
||||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
|
|||||||
@@ -14,6 +14,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast
|
from transformers import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast
|
||||||
@@ -141,6 +144,13 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
def big_tokenizer(self):
|
def big_tokenizer(self):
|
||||||
return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
|
return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
|
||||||
|
|
||||||
|
def test_picklable_without_disk(self):
|
||||||
|
with tempfile.NamedTemporaryFile() as f:
|
||||||
|
shutil.copyfile(SAMPLE_VOCAB, f.name)
|
||||||
|
tokenizer = XLMRobertaTokenizer(f.name, keep_accents=True)
|
||||||
|
pickled_tokenizer = pickle.dumps(tokenizer)
|
||||||
|
pickle.loads(pickled_tokenizer)
|
||||||
|
|
||||||
def test_rust_and_python_full_tokenizers(self):
|
def test_rust_and_python_full_tokenizers(self):
|
||||||
if not self.test_rust_tokenizer:
|
if not self.test_rust_tokenizer:
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user