From e0db8276a635b674553cf38aa23699c6340cffd6 Mon Sep 17 00:00:00 2001
From: Philip May
Date: Fri, 30 Apr 2021 09:44:58 +0200
Subject: [PATCH] add sp_model_kwargs to unpickle of xlm roberta tok (#11430)

add test for pickle

simplify test

fix test

code style

add missing pickle import

fix test

fix test

fix test
---
 .../models/xlm_roberta/tokenization_xlm_roberta.py | 13 +++++++++----
 tests/test_tokenization_xlm_roberta.py             | 13 +++++++++++++
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
index cda78e900d..9241c4f470 100644
--- a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
@@ -135,7 +135,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
         # Mask token behave like a normal word, i.e. include the space before it
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
 
-        sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
 
         super().__init__(
             bos_token=bos_token,
@@ -145,11 +145,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
             cls_token=cls_token,
             pad_token=pad_token,
             mask_token=mask_token,
-            sp_model_kwargs=sp_model_kwargs,
+            sp_model_kwargs=self.sp_model_kwargs,
             **kwargs,
         )
 
-        self.sp_model = spm.SentencePieceProcessor(**sp_model_kwargs)
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(str(vocab_file))
         self.vocab_file = vocab_file
 
@@ -175,7 +175,12 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
 
     def __setstate__(self, d):
         self.__dict__ = d
-        self.sp_model = spm.SentencePieceProcessor()
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)
 
     def build_inputs_with_special_tokens(
diff --git a/tests/test_tokenization_xlm_roberta.py b/tests/test_tokenization_xlm_roberta.py
index 8031ebc405..b9fe4dde62 100644
--- a/tests/test_tokenization_xlm_roberta.py
+++ b/tests/test_tokenization_xlm_roberta.py
@@ -16,6 +16,7 @@
 
 import itertools
 import os
+import pickle
 import unittest
 
 from transformers import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast
@@ -142,6 +143,18 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 
         self.assertFalse(all_equal)
 
+    def test_pickle_subword_regularization_tokenizer(self):
+        """Google pickle __getstate__ __setstate__ if you are struggling with this."""
+        # Subword regularization is only available for the slow tokenizer.
+        sp_model_kwargs = {"enable_sampling": True, "alpha": 0.1, "nbest_size": -1}
+        tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True, sp_model_kwargs=sp_model_kwargs)
+        tokenizer_bin = pickle.dumps(tokenizer)
+        tokenizer_new = pickle.loads(tokenizer_bin)
+
+        self.assertIsNotNone(tokenizer_new.sp_model_kwargs)
+        self.assertTrue(isinstance(tokenizer_new.sp_model_kwargs, dict))
+        self.assertEqual(tokenizer_new.sp_model_kwargs, sp_model_kwargs)
+
     @cached_property
     def big_tokenizer(self):
         return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")