From 870b734bfd2cc83e43b29050fba03709a0c5b539 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 15 Apr 2019 12:03:56 +0200 Subject: [PATCH] added tokenizers serialization tests --- pytorch_pretrained_bert/tokenization.py | 1 + pytorch_pretrained_bert/tokenization_gpt2.py | 6 ++- .../tokenization_openai.py | 6 ++- .../tokenization_transfo_xl.py | 1 + tests/tokenization_openai_test.py | 16 +++++++ tests/tokenization_test.py | 11 +++++ tests/tokenization_transfo_xl_test.py | 42 ++++++------------- 7 files changed, 51 insertions(+), 32 deletions(-) diff --git a/pytorch_pretrained_bert/tokenization.py b/pytorch_pretrained_bert/tokenization.py index 6e2e11ed92..8fd65f55f0 100644 --- a/pytorch_pretrained_bert/tokenization.py +++ b/pytorch_pretrained_bert/tokenization.py @@ -146,6 +146,7 @@ class BertTokenizer(object): index = token_index writer.write(token + u'\n') index += 1 + return vocab_file @classmethod def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): diff --git a/pytorch_pretrained_bert/tokenization_gpt2.py b/pytorch_pretrained_bert/tokenization_gpt2.py index 07db995b96..b49e1310e4 100644 --- a/pytorch_pretrained_bert/tokenization_gpt2.py +++ b/pytorch_pretrained_bert/tokenization_gpt2.py @@ -188,7 +188,10 @@ class GPT2Tokenizer(object): return word def save_vocabulary(self, vocab_path): - """Save the tokenizer vocabulary to a path.""" + """Save the tokenizer vocabulary and merge files to a directory.""" + if not os.path.isdir(vocab_path): + logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) + return vocab_file = os.path.join(vocab_path, VOCAB_NAME) merge_file = os.path.join(vocab_path, MERGES_NAME) json.dump(self.encoder, vocab_file) @@ -202,6 +205,7 @@ class GPT2Tokenizer(object): index = token_index writer.write(bpe_tokens + u'\n') index += 1 + return vocab_file, merge_file def encode(self, text): bpe_tokens = [] diff --git a/pytorch_pretrained_bert/tokenization_openai.py b/pytorch_pretrained_bert/tokenization_openai.py index aa0438ccf8..f3ce7ab251 100644 --- a/pytorch_pretrained_bert/tokenization_openai.py +++ b/pytorch_pretrained_bert/tokenization_openai.py @@ -263,7 +263,10 @@ class OpenAIGPTTokenizer(object): return out_string def save_vocabulary(self, vocab_path): - """Save the tokenizer vocabulary to a path.""" + """Save the tokenizer vocabulary and merge files to a directory.""" + if not os.path.isdir(vocab_path): + logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) + return vocab_file = os.path.join(vocab_path, VOCAB_NAME) merge_file = os.path.join(vocab_path, MERGES_NAME) json.dump(self.encoder, vocab_file) @@ -277,3 +280,4 @@ class OpenAIGPTTokenizer(object): index = token_index writer.write(bpe_tokens + u'\n') index += 1 + return vocab_file, merge_file diff --git a/pytorch_pretrained_bert/tokenization_transfo_xl.py b/pytorch_pretrained_bert/tokenization_transfo_xl.py index b6470c7667..f704a035db 100644 --- a/pytorch_pretrained_bert/tokenization_transfo_xl.py +++ b/pytorch_pretrained_bert/tokenization_transfo_xl.py @@ -148,6 +148,7 @@ class TransfoXLTokenizer(object): index = 0 vocab_file = os.path.join(vocab_path, VOCAB_NAME) torch.save(self.__dict__, vocab_file) + return vocab_file def build_vocab(self): if self.vocab_file: diff --git a/tests/tokenization_openai_test.py b/tests/tokenization_openai_test.py index 6213eb1b03..2b1bdd3a9a 100644 --- a/tests/tokenization_openai_test.py +++ b/tests/tokenization_openai_test.py @@ -52,5 +52,21 @@ class OpenAIGPTTokenizationTest(unittest.TestCase): self.assertListEqual( tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + vocab_file, merges_file = tokenizer.save_vocabulary(vocab_path="/tmp/") + tokenizer.from_pretrained("/tmp/") + os.remove(vocab_file) + os.remove(merges_file) + + text = "lower" + bpe_tokens = ["low", "er"] + tokens = tokenizer.tokenize(text) + self.assertListEqual(tokens, bpe_tokens) + + input_tokens = tokens + [""] + input_bpe_tokens = [14, 15, 20] + self.assertListEqual( + tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + + if __name__ == '__main__': unittest.main() diff --git a/tests/tokenization_test.py b/tests/tokenization_test.py index 78e145ffd2..15cc7ccd82 100644 --- a/tests/tokenization_test.py +++ b/tests/tokenization_test.py @@ -46,6 +46,17 @@ class TokenizationTest(unittest.TestCase): self.assertListEqual( tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) + vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/") + tokenizer.from_pretrained(vocab_file) + os.remove(vocab_file) + + tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") + self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) + + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) + + def test_chinese(self): tokenizer = BasicTokenizer() diff --git a/tests/tokenization_transfo_xl_test.py b/tests/tokenization_transfo_xl_test.py index 9ff04f5f34..add2eb4e71 100644 --- a/tests/tokenization_transfo_xl_test.py +++ b/tests/tokenization_transfo_xl_test.py @@ -18,9 +18,7 @@ import os import unittest from io import open -from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer, - _is_control, _is_punctuation, - _is_whitespace) +from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer class TransfoXLTokenizationTest(unittest.TestCase): @@ -43,6 +41,17 @@ class TransfoXLTokenizationTest(unittest.TestCase): self.assertListEqual( tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) + vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/") + tokenizer.from_pretrained(vocab_file) + os.remove(vocab_file) + + tokens = tokenizer.tokenize(u" UNwant\u00E9d,running") + self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) + + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) + + def test_full_tokenizer_lower(self): tokenizer = TransfoXLTokenizer(lower_case=True) @@ -58,33 +67,6 @@ class TransfoXLTokenizationTest(unittest.TestCase): tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]) - def test_is_whitespace(self): - self.assertTrue(_is_whitespace(u" ")) - self.assertTrue(_is_whitespace(u"\t")) - self.assertTrue(_is_whitespace(u"\r")) - self.assertTrue(_is_whitespace(u"\n")) - self.assertTrue(_is_whitespace(u"\u00A0")) - - self.assertFalse(_is_whitespace(u"A")) - self.assertFalse(_is_whitespace(u"-")) - - def test_is_control(self): - self.assertTrue(_is_control(u"\u0005")) - - self.assertFalse(_is_control(u"A")) - self.assertFalse(_is_control(u" ")) - self.assertFalse(_is_control(u"\t")) - self.assertFalse(_is_control(u"\r")) - - def test_is_punctuation(self): - self.assertTrue(_is_punctuation(u"-")) - self.assertTrue(_is_punctuation(u"$")) - self.assertTrue(_is_punctuation(u"`")) - self.assertTrue(_is_punctuation(u".")) - - self.assertFalse(_is_punctuation(u"A")) - self.assertFalse(_is_punctuation(u" ")) - if __name__ == '__main__': unittest.main()