add sp_model_kwargs to unpickle of xlm roberta tok (#11430)
* add test for pickle
* simplify test
* fix test
* code style
* add missing pickle import
* fix test
* fix test
* fix test
This commit is contained in:
@@ -135,7 +135,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
|
|||||||
# Mask token behave like a normal word, i.e. include the space before it
|
# Mask token behave like a normal word, i.e. include the space before it
|
||||||
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
|
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
|
||||||
|
|
||||||
sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
bos_token=bos_token,
|
bos_token=bos_token,
|
||||||
@@ -145,11 +145,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
|
|||||||
cls_token=cls_token,
|
cls_token=cls_token,
|
||||||
pad_token=pad_token,
|
pad_token=pad_token,
|
||||||
mask_token=mask_token,
|
mask_token=mask_token,
|
||||||
sp_model_kwargs=sp_model_kwargs,
|
sp_model_kwargs=self.sp_model_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.sp_model = spm.SentencePieceProcessor(**sp_model_kwargs)
|
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||||
self.sp_model.Load(str(vocab_file))
|
self.sp_model.Load(str(vocab_file))
|
||||||
self.vocab_file = vocab_file
|
self.vocab_file = vocab_file
|
||||||
|
|
||||||
@@ -175,7 +175,12 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def __setstate__(self, d):
|
def __setstate__(self, d):
|
||||||
self.__dict__ = d
|
self.__dict__ = d
|
||||||
self.sp_model = spm.SentencePieceProcessor()
|
|
||||||
|
# for backward compatibility
|
||||||
|
if not hasattr(self, "sp_model_kwargs"):
|
||||||
|
self.sp_model_kwargs = {}
|
||||||
|
|
||||||
|
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||||
self.sp_model.Load(self.vocab_file)
|
self.sp_model.Load(self.vocab_file)
|
||||||
|
|
||||||
def build_inputs_with_special_tokens(
|
def build_inputs_with_special_tokens(
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast
|
from transformers import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast
|
||||||
@@ -142,6 +143,18 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertFalse(all_equal)
|
self.assertFalse(all_equal)
|
||||||
|
|
||||||
|
def test_pickle_subword_regularization_tokenizer(self):
|
||||||
|
"""Google pickle __getstate__ __setstate__ if you are struggling with this."""
|
||||||
|
# Subword regularization is only available for the slow tokenizer.
|
||||||
|
sp_model_kwargs = {"enable_sampling": True, "alpha": 0.1, "nbest_size": -1}
|
||||||
|
tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True, sp_model_kwargs=sp_model_kwargs)
|
||||||
|
tokenizer_bin = pickle.dumps(tokenizer)
|
||||||
|
tokenizer_new = pickle.loads(tokenizer_bin)
|
||||||
|
|
||||||
|
self.assertIsNotNone(tokenizer_new.sp_model_kwargs)
|
||||||
|
self.assertTrue(isinstance(tokenizer_new.sp_model_kwargs, dict))
|
||||||
|
self.assertEqual(tokenizer_new.sp_model_kwargs, sp_model_kwargs)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def big_tokenizer(self):
|
def big_tokenizer(self):
|
||||||
return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
|
return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
|
||||||
|
|||||||
Reference in New Issue
Block a user