added tokenizers serialization tests
This commit is contained in:
@@ -146,6 +146,7 @@ class BertTokenizer(object):
|
||||
index = token_index
|
||||
writer.write(token + u'\n')
|
||||
index += 1
|
||||
return vocab_file
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||
|
||||
@@ -188,7 +188,10 @@ class GPT2Tokenizer(object):
|
||||
return word
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
"""Save the tokenizer vocabulary to a path."""
|
||||
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||
if not os.path.isdir(vocab_path):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
||||
return
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||
merge_file = os.path.join(vocab_path, MERGES_NAME)
|
||||
json.dump(self.encoder, vocab_file)
|
||||
@@ -202,6 +205,7 @@ class GPT2Tokenizer(object):
|
||||
index = token_index
|
||||
writer.write(bpe_tokens + u'\n')
|
||||
index += 1
|
||||
return vocab_file, merge_file
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
|
||||
@@ -263,7 +263,10 @@ class OpenAIGPTTokenizer(object):
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
"""Save the tokenizer vocabulary to a path."""
|
||||
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||
if not os.path.isdir(vocab_path):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
||||
return
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||
merge_file = os.path.join(vocab_path, MERGES_NAME)
|
||||
json.dump(self.encoder, vocab_file)
|
||||
@@ -277,3 +280,4 @@ class OpenAIGPTTokenizer(object):
|
||||
index = token_index
|
||||
writer.write(bpe_tokens + u'\n')
|
||||
index += 1
|
||||
return vocab_file, merge_file
|
||||
|
||||
@@ -148,6 +148,7 @@ class TransfoXLTokenizer(object):
|
||||
index = 0
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||
torch.save(self.__dict__, vocab_file)
|
||||
return vocab_file
|
||||
|
||||
def build_vocab(self):
|
||||
if self.vocab_file:
|
||||
|
||||
@@ -52,5 +52,21 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
vocab_file, merges_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
|
||||
tokenizer.from_pretrained("/tmp/")
|
||||
os.remove(vocab_file)
|
||||
os.remove(merges_file)
|
||||
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er</w>"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
input_tokens = tokens + ["<unk>"]
|
||||
input_bpe_tokens = [14, 15, 20]
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -46,6 +46,17 @@ class TokenizationTest(unittest.TestCase):
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||
|
||||
vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
|
||||
tokenizer.from_pretrained(vocab_file)
|
||||
os.remove(vocab_file)
|
||||
|
||||
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||
|
||||
|
||||
def test_chinese(self):
|
||||
tokenizer = BasicTokenizer()
|
||||
|
||||
|
||||
@@ -18,9 +18,7 @@ import os
|
||||
import unittest
|
||||
from io import open
|
||||
|
||||
from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer,
|
||||
_is_control, _is_punctuation,
|
||||
_is_whitespace)
|
||||
from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer
|
||||
|
||||
|
||||
class TransfoXLTokenizationTest(unittest.TestCase):
|
||||
@@ -43,6 +41,17 @@ class TransfoXLTokenizationTest(unittest.TestCase):
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
|
||||
|
||||
vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
|
||||
tokenizer.from_pretrained(vocab_file)
|
||||
os.remove(vocab_file)
|
||||
|
||||
tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
|
||||
self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
|
||||
|
||||
|
||||
def test_full_tokenizer_lower(self):
|
||||
tokenizer = TransfoXLTokenizer(lower_case=True)
|
||||
|
||||
@@ -58,33 +67,6 @@ class TransfoXLTokenizationTest(unittest.TestCase):
|
||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||
["HeLLo", "!", "how", "Are", "yoU", "?"])
|
||||
|
||||
def test_is_whitespace(self):
|
||||
self.assertTrue(_is_whitespace(u" "))
|
||||
self.assertTrue(_is_whitespace(u"\t"))
|
||||
self.assertTrue(_is_whitespace(u"\r"))
|
||||
self.assertTrue(_is_whitespace(u"\n"))
|
||||
self.assertTrue(_is_whitespace(u"\u00A0"))
|
||||
|
||||
self.assertFalse(_is_whitespace(u"A"))
|
||||
self.assertFalse(_is_whitespace(u"-"))
|
||||
|
||||
def test_is_control(self):
|
||||
self.assertTrue(_is_control(u"\u0005"))
|
||||
|
||||
self.assertFalse(_is_control(u"A"))
|
||||
self.assertFalse(_is_control(u" "))
|
||||
self.assertFalse(_is_control(u"\t"))
|
||||
self.assertFalse(_is_control(u"\r"))
|
||||
|
||||
def test_is_punctuation(self):
|
||||
self.assertTrue(_is_punctuation(u"-"))
|
||||
self.assertTrue(_is_punctuation(u"$"))
|
||||
self.assertTrue(_is_punctuation(u"`"))
|
||||
self.assertTrue(_is_punctuation(u"."))
|
||||
|
||||
self.assertFalse(_is_punctuation(u"A"))
|
||||
self.assertFalse(_is_punctuation(u" "))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user