added tokenizers serialization tests
This commit is contained in:
@@ -146,6 +146,7 @@ class BertTokenizer(object):
|
|||||||
index = token_index
|
index = token_index
|
||||||
writer.write(token + u'\n')
|
writer.write(token + u'\n')
|
||||||
index += 1
|
index += 1
|
||||||
|
return vocab_file
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||||
|
|||||||
@@ -188,7 +188,10 @@ class GPT2Tokenizer(object):
|
|||||||
return word
|
return word
|
||||||
|
|
||||||
def save_vocabulary(self, vocab_path):
|
def save_vocabulary(self, vocab_path):
|
||||||
"""Save the tokenizer vocabulary to a path."""
|
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||||
|
if not os.path.isdir(vocab_path):
|
||||||
|
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
||||||
|
return
|
||||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||||
merge_file = os.path.join(vocab_path, MERGES_NAME)
|
merge_file = os.path.join(vocab_path, MERGES_NAME)
|
||||||
json.dump(self.encoder, vocab_file)
|
json.dump(self.encoder, vocab_file)
|
||||||
@@ -202,6 +205,7 @@ class GPT2Tokenizer(object):
|
|||||||
index = token_index
|
index = token_index
|
||||||
writer.write(bpe_tokens + u'\n')
|
writer.write(bpe_tokens + u'\n')
|
||||||
index += 1
|
index += 1
|
||||||
|
return vocab_file, merge_file
|
||||||
|
|
||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
bpe_tokens = []
|
bpe_tokens = []
|
||||||
|
|||||||
@@ -263,7 +263,10 @@ class OpenAIGPTTokenizer(object):
|
|||||||
return out_string
|
return out_string
|
||||||
|
|
||||||
def save_vocabulary(self, vocab_path):
|
def save_vocabulary(self, vocab_path):
|
||||||
"""Save the tokenizer vocabulary to a path."""
|
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||||
|
if not os.path.isdir(vocab_path):
|
||||||
|
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
||||||
|
return
|
||||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||||
merge_file = os.path.join(vocab_path, MERGES_NAME)
|
merge_file = os.path.join(vocab_path, MERGES_NAME)
|
||||||
json.dump(self.encoder, vocab_file)
|
json.dump(self.encoder, vocab_file)
|
||||||
@@ -277,3 +280,4 @@ class OpenAIGPTTokenizer(object):
|
|||||||
index = token_index
|
index = token_index
|
||||||
writer.write(bpe_tokens + u'\n')
|
writer.write(bpe_tokens + u'\n')
|
||||||
index += 1
|
index += 1
|
||||||
|
return vocab_file, merge_file
|
||||||
|
|||||||
@@ -148,6 +148,7 @@ class TransfoXLTokenizer(object):
|
|||||||
index = 0
|
index = 0
|
||||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||||
torch.save(self.__dict__, vocab_file)
|
torch.save(self.__dict__, vocab_file)
|
||||||
|
return vocab_file
|
||||||
|
|
||||||
def build_vocab(self):
|
def build_vocab(self):
|
||||||
if self.vocab_file:
|
if self.vocab_file:
|
||||||
|
|||||||
@@ -52,5 +52,21 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|
||||||
|
vocab_file, merges_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
|
||||||
|
tokenizer.from_pretrained("/tmp/")
|
||||||
|
os.remove(vocab_file)
|
||||||
|
os.remove(merges_file)
|
||||||
|
|
||||||
|
text = "lower"
|
||||||
|
bpe_tokens = ["low", "er</w>"]
|
||||||
|
tokens = tokenizer.tokenize(text)
|
||||||
|
self.assertListEqual(tokens, bpe_tokens)
|
||||||
|
|
||||||
|
input_tokens = tokens + ["<unk>"]
|
||||||
|
input_bpe_tokens = [14, 15, 20]
|
||||||
|
self.assertListEqual(
|
||||||
|
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -46,6 +46,17 @@ class TokenizationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||||
|
|
||||||
|
vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
|
||||||
|
tokenizer.from_pretrained(vocab_file)
|
||||||
|
os.remove(vocab_file)
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||||
|
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||||
|
|
||||||
|
|
||||||
def test_chinese(self):
|
def test_chinese(self):
|
||||||
tokenizer = BasicTokenizer()
|
tokenizer = BasicTokenizer()
|
||||||
|
|
||||||
|
|||||||
@@ -18,9 +18,7 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer,
|
from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer
|
||||||
_is_control, _is_punctuation,
|
|
||||||
_is_whitespace)
|
|
||||||
|
|
||||||
|
|
||||||
class TransfoXLTokenizationTest(unittest.TestCase):
|
class TransfoXLTokenizationTest(unittest.TestCase):
|
||||||
@@ -43,6 +41,17 @@ class TransfoXLTokenizationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
|
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
|
||||||
|
|
||||||
|
vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
|
||||||
|
tokenizer.from_pretrained(vocab_file)
|
||||||
|
os.remove(vocab_file)
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
|
||||||
|
self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
|
||||||
|
|
||||||
|
|
||||||
def test_full_tokenizer_lower(self):
|
def test_full_tokenizer_lower(self):
|
||||||
tokenizer = TransfoXLTokenizer(lower_case=True)
|
tokenizer = TransfoXLTokenizer(lower_case=True)
|
||||||
|
|
||||||
@@ -58,33 +67,6 @@ class TransfoXLTokenizationTest(unittest.TestCase):
|
|||||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||||
["HeLLo", "!", "how", "Are", "yoU", "?"])
|
["HeLLo", "!", "how", "Are", "yoU", "?"])
|
||||||
|
|
||||||
def test_is_whitespace(self):
|
|
||||||
self.assertTrue(_is_whitespace(u" "))
|
|
||||||
self.assertTrue(_is_whitespace(u"\t"))
|
|
||||||
self.assertTrue(_is_whitespace(u"\r"))
|
|
||||||
self.assertTrue(_is_whitespace(u"\n"))
|
|
||||||
self.assertTrue(_is_whitespace(u"\u00A0"))
|
|
||||||
|
|
||||||
self.assertFalse(_is_whitespace(u"A"))
|
|
||||||
self.assertFalse(_is_whitespace(u"-"))
|
|
||||||
|
|
||||||
def test_is_control(self):
|
|
||||||
self.assertTrue(_is_control(u"\u0005"))
|
|
||||||
|
|
||||||
self.assertFalse(_is_control(u"A"))
|
|
||||||
self.assertFalse(_is_control(u" "))
|
|
||||||
self.assertFalse(_is_control(u"\t"))
|
|
||||||
self.assertFalse(_is_control(u"\r"))
|
|
||||||
|
|
||||||
def test_is_punctuation(self):
|
|
||||||
self.assertTrue(_is_punctuation(u"-"))
|
|
||||||
self.assertTrue(_is_punctuation(u"$"))
|
|
||||||
self.assertTrue(_is_punctuation(u"`"))
|
|
||||||
self.assertTrue(_is_punctuation(u"."))
|
|
||||||
|
|
||||||
self.assertFalse(_is_punctuation(u"A"))
|
|
||||||
self.assertFalse(_is_punctuation(u" "))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user