From 7044ed6b059c7305b0a1ab8576c775829afd9226 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Aug 2019 17:36:11 +0200 Subject: [PATCH] fix tokenizers serialization --- pytorch_transformers/tests/tokenization_dilbert_test.py | 4 ++-- pytorch_transformers/tests/tokenization_tests_commons.py | 4 ++-- pytorch_transformers/tokenization_transfo_xl.py | 3 ++- pytorch_transformers/tokenization_xlnet.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_dilbert_test.py b/pytorch_transformers/tests/tokenization_dilbert_test.py index 30268db216..42f8060998 100644 --- a/pytorch_transformers/tests/tokenization_dilbert_test.py +++ b/pytorch_transformers/tests/tokenization_dilbert_test.py @@ -27,8 +27,8 @@ class DistilBertTokenizationTest(BertTokenizationTest): tokenizer_class = DistilBertTokenizer - def get_tokenizer(self): - return DistilBertTokenizer.from_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs): + return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs) def test_sequence_builders(self): tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index 779a3ba6c3..6578c5c3a5 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -67,13 +67,13 @@ class CommonTestCases: with TemporaryDirectory() as tmpdirname: tokenizer.save_pretrained(tmpdirname) - tokenizer = tokenizer.from_pretrained(tmpdirname) + tokenizer = self.tokenizer_class.from_pretrained(tmpdirname) after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") self.assertListEqual(before_tokens, after_tokens) self.assertEqual(tokenizer.max_len, 42) - tokenizer = tokenizer.from_pretrained(tmpdirname, max_len=43) + tokenizer = self.tokenizer_class.from_pretrained(tmpdirname, max_len=43) self.assertEqual(tokenizer.max_len, 43) def test_pickle_tokenizer(self): diff --git a/pytorch_transformers/tokenization_transfo_xl.py b/pytorch_transformers/tokenization_transfo_xl.py index c603ba695c..66bc01c1bb 100644 --- a/pytorch_transformers/tokenization_transfo_xl.py +++ b/pytorch_transformers/tokenization_transfo_xl.py @@ -95,7 +95,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer): # in a library like ours, at all. vocab_dict = torch.load(pretrained_vocab_file) for key, value in vocab_dict.items(): - self.__dict__[key] = value + if key not in self.__dict__: + self.__dict__[key] = value if vocab_file is not None: self.build_vocab() diff --git a/pytorch_transformers/tokenization_xlnet.py b/pytorch_transformers/tokenization_xlnet.py index ac7231bb68..bf9b9dc782 100644 --- a/pytorch_transformers/tokenization_xlnet.py +++ b/pytorch_transformers/tokenization_xlnet.py @@ -61,7 +61,7 @@ class XLNetTokenizer(PreTrainedTokenizer): pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, vocab_file, max_len=None, + def __init__(self, vocab_file, do_lower_case=False, remove_space=True, keep_accents=False, bos_token="", eos_token="", unk_token="", sep_token="", pad_token="", cls_token="", mask_token="",