update tokenizer - update squad example for xlnet
This commit is contained in:
@@ -38,7 +38,10 @@ class TokenizationTest(unittest.TestCase):
|
||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
create_and_check_tokenizer_commons(self, BertTokenizer, tmpdirname)
|
||||
input_text = u"UNwant\u00E9d,running"
|
||||
output_text = u"unwanted, running"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname)
|
||||
|
||||
tokenizer = BertTokenizer(vocab_file)
|
||||
|
||||
|
||||
@@ -41,7 +41,10 @@ class GPT2TokenizationTest(unittest.TestCase):
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map)
|
||||
input_text = u"lower newer"
|
||||
output_text = u"lower<unk>newer"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)
|
||||
|
||||
tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
|
||||
text = "lower"
|
||||
|
||||
@@ -42,7 +42,10 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname)
|
||||
input_text = u"lower newer"
|
||||
output_text = u"lower newer"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)
|
||||
|
||||
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
|
||||
|
||||
|
||||
@@ -113,23 +113,24 @@ def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kw
|
||||
tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
|
||||
|
||||
|
||||
def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||
|
||||
text = u"He is very happy, UNwant\u00E9d,running"
|
||||
tokens = tokenizer.tokenize(text)
|
||||
tokens = tokenizer.tokenize(input_text)
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
ids_2 = tokenizer.encode(text)
|
||||
ids_2 = tokenizer.encode(input_text)
|
||||
tester.assertListEqual(ids, ids_2)
|
||||
|
||||
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
|
||||
text_2 = tokenizer.decode(ids)
|
||||
|
||||
tester.assertEqual(text_2, output_text)
|
||||
|
||||
tester.assertNotEqual(len(tokens_2), 0)
|
||||
tester.assertIsInstance(text_2, (str, unicode))
|
||||
|
||||
def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
|
||||
create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
|
||||
create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
|
||||
@@ -34,7 +34,10 @@ class TransfoXLTokenizationTest(unittest.TestCase):
|
||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True)
|
||||
input_text = u"<unk> UNwanted , running"
|
||||
output_text = u"<unk> unwanted, running"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)
|
||||
|
||||
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
|
||||
|
||||
|
||||
@@ -41,7 +41,10 @@ class XLMTokenizationTest(unittest.TestCase):
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname)
|
||||
input_text = u"lower newer"
|
||||
output_text = u"lower newer"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)
|
||||
|
||||
tokenizer = XLMTokenizer(vocab_file, merges_file)
|
||||
|
||||
|
||||
@@ -32,7 +32,10 @@ class XLNetTokenizationTest(unittest.TestCase):
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
tokenizer.save_pretrained(tmpdirname)
|
||||
|
||||
create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
|
||||
input_text = u"This is a test"
|
||||
output_text = u"This is a test"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)
|
||||
|
||||
tokens = tokenizer.tokenize(u'This is a test')
|
||||
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
|
||||
|
||||
Reference in New Issue
Block a user