update tokenizer - update squad example for xlnet

This commit is contained in:
thomwolf
2019-07-15 17:30:42 +02:00
parent 3b469cb422
commit 15d8b1266c
20 changed files with 191 additions and 131 deletions

View File

@@ -38,7 +38,10 @@ class TokenizationTest(unittest.TestCase):
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
create_and_check_tokenizer_commons(self, BertTokenizer, tmpdirname)
input_text = u"UNwant\u00E9d,running"
output_text = u"unwanted, running"
create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname)
tokenizer = BertTokenizer(vocab_file)

View File

@@ -41,7 +41,10 @@ class GPT2TokenizationTest(unittest.TestCase):
with open(merges_file, "w") as fp:
fp.write("\n".join(merges))
create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map)
input_text = u"lower newer"
output_text = u"lower<unk>newer"
create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)
tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
text = "lower"

View File

@@ -42,7 +42,10 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
with open(merges_file, "w") as fp:
fp.write("\n".join(merges))
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname)
input_text = u"lower newer"
output_text = u"lower newer"
create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)

View File

@@ -113,23 +113,24 @@ def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kw
tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
    """Check the tokenize / encode / decode round trip of a tokenizer.

    Args:
        tester: object exposing the ``assertListEqual``, ``assertEqual``,
            ``assertNotEqual`` and ``assertIsInstance`` assertion methods.
        input_text: text passed to ``tokenize`` and ``encode``.
        output_text: text that ``decode`` is expected to return for the ids
            obtained from ``input_text``.
        tokenizer_class: class providing a ``from_pretrained`` constructor.
        *inputs: positional arguments forwarded to ``from_pretrained``.
        **kwargs: keyword arguments forwarded to ``from_pretrained``.
    """
    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)

    # `encode` must agree with the two-step tokenize -> convert_tokens_to_ids path.
    tokens = tokenizer.tokenize(input_text)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    ids_2 = tokenizer.encode(input_text)
    tester.assertListEqual(ids, ids_2)

    tokens_2 = tokenizer.convert_ids_to_tokens(ids)
    text_2 = tokenizer.decode(ids)

    tester.assertEqual(text_2, output_text)
    tester.assertNotEqual(len(tokens_2), 0)

    # `unicode` only exists on Python 2 (or when the module defines a shim);
    # fall back to `str` alone so the check does not raise NameError.
    try:
        text_types = (str, unicode)  # noqa: F821
    except NameError:
        text_types = (str,)
    tester.assertIsInstance(text_2, text_types)
def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
    """Run the shared tokenizer checks against ``tokenizer_class``.

    ``input_text`` and ``output_text`` are only forwarded to the required
    methods check; the add-tokens, save-and-load and pickle checks receive
    ``tester``, ``tokenizer_class`` and the remaining arguments unchanged.

    Args:
        tester: test-case object handed to every check.
        input_text: text forwarded to the required methods check.
        output_text: expected text forwarded to the required methods check.
        tokenizer_class: tokenizer class under test.
        *inputs: positional arguments forwarded to every check.
        **kwargs: keyword arguments forwarded to every check.
    """
    create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
    create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
    create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
    create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)

View File

@@ -34,7 +34,10 @@ class TransfoXLTokenizationTest(unittest.TestCase):
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True)
input_text = u"<unk> UNwanted , running"
output_text = u"<unk> unwanted, running"
create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)

View File

@@ -41,7 +41,10 @@ class XLMTokenizationTest(unittest.TestCase):
with open(merges_file, "w") as fp:
fp.write("\n".join(merges))
create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname)
input_text = u"lower newer"
output_text = u"lower newer"
create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)
tokenizer = XLMTokenizer(vocab_file, merges_file)

View File

@@ -32,7 +32,10 @@ class XLNetTokenizationTest(unittest.TestCase):
with TemporaryDirectory() as tmpdirname:
tokenizer.save_pretrained(tmpdirname)
create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
input_text = u"This is a test"
output_text = u"This is a test"
create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)
tokens = tokenizer.tokenize(u'This is a test')
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])