From 69da972ace6fd574a528ef269ebcee32305d18ff Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Aug 2019 17:09:36 +0200 Subject: [PATCH] added test and debug tokenizer configuration serialization --- .../tests/tokenization_bert_test.py | 4 ++-- .../tests/tokenization_gpt2_test.py | 5 +++-- .../tests/tokenization_openai_test.py | 4 ++-- .../tests/tokenization_roberta_test.py | 5 +++-- .../tests/tokenization_tests_commons.py | 15 ++++++++++++--- .../tests/tokenization_transfo_xl_test.py | 5 +++-- .../tests/tokenization_xlm_test.py | 4 ++-- .../tests/tokenization_xlnet_test.py | 4 ++-- pytorch_transformers/tokenization_utils.py | 4 +++- 9 files changed, 32 insertions(+), 18 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_bert_test.py b/pytorch_transformers/tests/tokenization_bert_test.py index db507317a8..290b357820 100644 --- a/pytorch_transformers/tests/tokenization_bert_test.py +++ b/pytorch_transformers/tests/tokenization_bert_test.py @@ -41,8 +41,8 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) - def get_tokenizer(self): - return BertTokenizer.from_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs): + return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"UNwant\u00E9d,running" diff --git a/pytorch_transformers/tests/tokenization_gpt2_test.py b/pytorch_transformers/tests/tokenization_gpt2_test.py index da7028c27d..252dbfe6f4 100644 --- a/pytorch_transformers/tests/tokenization_gpt2_test.py +++ b/pytorch_transformers/tests/tokenization_gpt2_test.py @@ -44,8 +44,9 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.merges_file, "w") as fp: fp.write("\n".join(merges)) - def get_tokenizer(self): - return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map) + def get_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"lower newer" diff --git a/pytorch_transformers/tests/tokenization_openai_test.py b/pytorch_transformers/tests/tokenization_openai_test.py index bb354f3fb7..6b86416d2d 100644 --- a/pytorch_transformers/tests/tokenization_openai_test.py +++ b/pytorch_transformers/tests/tokenization_openai_test.py @@ -45,8 +45,8 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.merges_file, "w") as fp: fp.write("\n".join(merges)) - def get_tokenizer(self): - return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs): + return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"lower newer" diff --git a/pytorch_transformers/tests/tokenization_roberta_test.py b/pytorch_transformers/tests/tokenization_roberta_test.py index a8f940ae43..5f9b65a7a3 100644 --- a/pytorch_transformers/tests/tokenization_roberta_test.py +++ b/pytorch_transformers/tests/tokenization_roberta_test.py @@ -43,8 +43,9 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.merges_file, "w") as fp: fp.write("\n".join(merges)) - def get_tokenizer(self): - return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map) + def get_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"lower newer" diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index ebcf6f48d8..779a3ba6c3 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -49,14 +49,19 @@ class CommonTestCases: def tearDown(self): shutil.rmtree(self.tmpdirname) - def get_tokenizer(self): + def get_tokenizer(self, **kwargs): raise NotImplementedError def get_input_output_texts(self): raise NotImplementedError def test_save_and_load_tokenizer(self): + # safety check on max_len default value so we are sure the test works tokenizer = self.get_tokenizer() + self.assertNotEqual(tokenizer.max_len, 42) + + # Now let's start the test + tokenizer = self.get_tokenizer(max_len=42) before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") @@ -64,8 +69,12 @@ class CommonTestCases: tokenizer.save_pretrained(tmpdirname) tokenizer = tokenizer.from_pretrained(tmpdirname) - after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") - self.assertListEqual(before_tokens, after_tokens) + after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") + self.assertListEqual(before_tokens, after_tokens) + + self.assertEqual(tokenizer.max_len, 42) + tokenizer = tokenizer.from_pretrained(tmpdirname, max_len=43) + self.assertEqual(tokenizer.max_len, 43) def test_pickle_tokenizer(self): tokenizer = self.get_tokenizer() diff --git a/pytorch_transformers/tests/tokenization_transfo_xl_test.py b/pytorch_transformers/tests/tokenization_transfo_xl_test.py index fbd06cf47e..f881cf1d2b 100644 --- a/pytorch_transformers/tests/tokenization_transfo_xl_test.py +++ b/pytorch_transformers/tests/tokenization_transfo_xl_test.py @@ -37,8 +37,9 @@ class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) - def get_tokenizer(self): - return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True) + def get_tokenizer(self, **kwargs): + kwargs['lower_case'] = True + return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u" UNwanted , running" diff --git a/pytorch_transformers/tests/tokenization_xlm_test.py b/pytorch_transformers/tests/tokenization_xlm_test.py index ede77a1f98..43f1e0c5dd 100644 --- a/pytorch_transformers/tests/tokenization_xlm_test.py +++ b/pytorch_transformers/tests/tokenization_xlm_test.py @@ -44,8 +44,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.merges_file, "w") as fp: fp.write("\n".join(merges)) - def get_tokenizer(self): - return XLMTokenizer.from_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs): + return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"lower newer" diff --git a/pytorch_transformers/tests/tokenization_xlnet_test.py b/pytorch_transformers/tests/tokenization_xlnet_test.py index 9feab7c0bd..c603ce55f9 100644 --- a/pytorch_transformers/tests/tokenization_xlnet_test.py +++ b/pytorch_transformers/tests/tokenization_xlnet_test.py @@ -35,8 +35,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) tokenizer.save_pretrained(self.tmpdirname) - def get_tokenizer(self): - return XLNetTokenizer.from_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs): + return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"This is a test" diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 51e59fe46c..8d7c29b16c 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -332,7 +332,7 @@ class PreTrainedTokenizer(object): tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None) if tokenizer_config_file is not None: init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8")) - saved_init_inputs = init_kwargs.pop('init_inputs', []) + saved_init_inputs = init_kwargs.pop('init_inputs', ()) if not init_inputs: init_inputs = saved_init_inputs else: @@ -399,6 +399,8 @@ class PreTrainedTokenizer(object): tokenizer_config = copy.deepcopy(self.init_kwargs) tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs) + for file_id in self.vocab_files_names.keys(): + tokenizer_config.pop(file_id, None) with open(tokenizer_config_file, 'w', encoding='utf-8') as f: f.write(json.dumps(tokenizer_config, ensure_ascii=False))