added test and debug tokenizer configuration serialization
This commit is contained in:
@@ -41,8 +41,8 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
def get_tokenizer(self):
|
||||
return BertTokenizer.from_pretrained(self.tmpdirname)
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = u"UNwant\u00E9d,running"
|
||||
|
||||
@@ -44,8 +44,9 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
with open(self.merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_tokenizer(self):
|
||||
return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
|
||||
def get_tokenizer(self, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = u"lower newer"
|
||||
|
||||
@@ -45,8 +45,8 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
with open(self.merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_tokenizer(self):
|
||||
return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname)
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = u"lower newer"
|
||||
|
||||
@@ -43,8 +43,9 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
with open(self.merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_tokenizer(self):
|
||||
return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
|
||||
def get_tokenizer(self, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = u"lower newer"
|
||||
|
||||
@@ -49,14 +49,19 @@ class CommonTestCases:
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self):
|
||||
def get_tokenizer(self, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_input_output_texts(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_save_and_load_tokenizer(self):
|
||||
# safety check on max_len default value so we are sure the test works
|
||||
tokenizer = self.get_tokenizer()
|
||||
self.assertNotEqual(tokenizer.max_len, 42)
|
||||
|
||||
# Now let's start the test
|
||||
tokenizer = self.get_tokenizer(max_len=42)
|
||||
|
||||
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
||||
|
||||
@@ -64,8 +69,12 @@ class CommonTestCases:
|
||||
tokenizer.save_pretrained(tmpdirname)
|
||||
tokenizer = tokenizer.from_pretrained(tmpdirname)
|
||||
|
||||
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
||||
self.assertListEqual(before_tokens, after_tokens)
|
||||
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
||||
self.assertListEqual(before_tokens, after_tokens)
|
||||
|
||||
self.assertEqual(tokenizer.max_len, 42)
|
||||
tokenizer = tokenizer.from_pretrained(tmpdirname, max_len=43)
|
||||
self.assertEqual(tokenizer.max_len, 43)
|
||||
|
||||
def test_pickle_tokenizer(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
@@ -37,8 +37,9 @@ class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
def get_tokenizer(self):
|
||||
return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True)
|
||||
def get_tokenizer(self, **kwargs):
|
||||
kwargs['lower_case'] = True
|
||||
return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = u"<unk> UNwanted , running"
|
||||
|
||||
@@ -44,8 +44,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
with open(self.merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_tokenizer(self):
|
||||
return XLMTokenizer.from_pretrained(self.tmpdirname)
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = u"lower newer"
|
||||
|
||||
@@ -35,8 +35,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self):
|
||||
return XLNetTokenizer.from_pretrained(self.tmpdirname)
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = u"This is a test"
|
||||
|
||||
@@ -332,7 +332,7 @@ class PreTrainedTokenizer(object):
|
||||
tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None)
|
||||
if tokenizer_config_file is not None:
|
||||
init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8"))
|
||||
saved_init_inputs = init_kwargs.pop('init_inputs', [])
|
||||
saved_init_inputs = init_kwargs.pop('init_inputs', ())
|
||||
if not init_inputs:
|
||||
init_inputs = saved_init_inputs
|
||||
else:
|
||||
@@ -399,6 +399,8 @@ class PreTrainedTokenizer(object):
|
||||
|
||||
tokenizer_config = copy.deepcopy(self.init_kwargs)
|
||||
tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs)
|
||||
for file_id in self.vocab_files_names.keys():
|
||||
tokenizer_config.pop(file_id, None)
|
||||
|
||||
with open(tokenizer_config_file, 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
|
||||
|
||||
Reference in New Issue
Block a user