add sp_model_kwargs to unpickle of xlm roberta tok (#11430)
add test for pickle simplify test fix test code style add missing pickle import fix test fix test fix test
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import pickle
|
||||
import unittest
|
||||
|
||||
from transformers import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast
|
||||
@@ -142,6 +143,18 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertFalse(all_equal)
|
||||
|
||||
def test_pickle_subword_regularization_tokenizer(self):
|
||||
"""Google pickle __getstate__ __setstate__ if you are struggling with this."""
|
||||
# Subword regularization is only available for the slow tokenizer.
|
||||
sp_model_kwargs = {"enable_sampling": True, "alpha": 0.1, "nbest_size": -1}
|
||||
tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True, sp_model_kwargs=sp_model_kwargs)
|
||||
tokenizer_bin = pickle.dumps(tokenizer)
|
||||
tokenizer_new = pickle.loads(tokenizer_bin)
|
||||
|
||||
self.assertIsNotNone(tokenizer_new.sp_model_kwargs)
|
||||
self.assertTrue(isinstance(tokenizer_new.sp_model_kwargs, dict))
|
||||
self.assertEqual(tokenizer_new.sp_model_kwargs, sp_model_kwargs)
|
||||
|
||||
@cached_property
|
||||
def big_tokenizer(self):
|
||||
return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
|
||||
|
||||
Reference in New Issue
Block a user